aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--appservice/api/query.go4
-rw-r--r--build/gobind-pinecone/monolith.go2
-rw-r--r--build/gobind-yggdrasil/monolith.go1
-rw-r--r--clientapi/clientapi.go4
-rw-r--r--clientapi/routing/createroom.go6
-rw-r--r--clientapi/routing/joinroom.go4
-rw-r--r--clientapi/routing/key_crosssigning.go4
-rw-r--r--clientapi/routing/login.go4
-rw-r--r--clientapi/routing/membership.go18
-rw-r--r--clientapi/routing/password.go4
-rw-r--r--clientapi/routing/peekroom.go6
-rw-r--r--clientapi/routing/profile.go14
-rw-r--r--clientapi/routing/register.go6
-rw-r--r--clientapi/routing/routing.go4
-rw-r--r--clientapi/routing/sendtyping.go4
-rw-r--r--clientapi/routing/threepid.go12
-rw-r--r--clientapi/threepid/invites.go8
-rw-r--r--cmd/create-account/main.go13
-rw-r--r--cmd/dendrite-demo-libp2p/main.go1
-rw-r--r--cmd/dendrite-demo-pinecone/main.go1
-rw-r--r--cmd/dendrite-demo-yggdrasil/main.go1
-rw-r--r--cmd/dendritejs-pinecone/main.go1
-rw-r--r--cmd/dendritejs/main.go1
-rw-r--r--cmd/generate-config/main.go1
-rw-r--r--cmd/goose/main.go24
-rw-r--r--internal/test/config.go1
-rw-r--r--keyserver/internal/internal.go82
-rw-r--r--keyserver/storage/interface.go2
-rw-r--r--keyserver/storage/postgres/device_keys_table.go33
-rw-r--r--keyserver/storage/shared/storage.go4
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go31
-rw-r--r--keyserver/storage/storage_test.go2
-rw-r--r--keyserver/storage/tables/interface.go2
-rw-r--r--setup/base/base.go12
-rw-r--r--setup/config/config_userapi.go6
-rw-r--r--setup/monolith.go4
-rw-r--r--userapi/api/api_logintoken.go7
-rw-r--r--userapi/internal/api.go67
-rw-r--r--userapi/internal/api_logintoken.go8
-rw-r--r--userapi/storage/devices/interface.go52
-rw-r--r--userapi/storage/devices/postgres/storage.go270
-rw-r--r--userapi/storage/devices/sqlite3/storage.go271
-rw-r--r--userapi/storage/devices/storage.go42
-rw-r--r--userapi/storage/devices/storage_wasm.go39
-rw-r--r--userapi/storage/interface.go (renamed from userapi/storage/accounts/interface.go)31
-rw-r--r--userapi/storage/postgres/account_data_table.go (renamed from userapi/storage/accounts/postgres/account_data_table.go)0
-rw-r--r--userapi/storage/postgres/accounts_table.go (renamed from userapi/storage/accounts/postgres/accounts_table.go)0
-rw-r--r--userapi/storage/postgres/deltas/20200929203058_is_active.go (renamed from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go)0
-rw-r--r--userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go (renamed from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go)5
-rw-r--r--userapi/storage/postgres/deltas/2022021013023800_add_account_type.go (renamed from userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go)0
-rw-r--r--userapi/storage/postgres/devices_table.go (renamed from userapi/storage/devices/postgres/devices_table.go)3
-rw-r--r--userapi/storage/postgres/key_backup_table.go (renamed from userapi/storage/accounts/postgres/key_backup_table.go)0
-rw-r--r--userapi/storage/postgres/key_backup_version_table.go (renamed from userapi/storage/accounts/postgres/key_backup_version_table.go)0
-rw-r--r--userapi/storage/postgres/logintoken_table.go (renamed from userapi/storage/devices/postgres/logintoken_table.go)3
-rw-r--r--userapi/storage/postgres/openid_table.go (renamed from userapi/storage/accounts/postgres/openid_table.go)0
-rw-r--r--userapi/storage/postgres/profile_table.go (renamed from userapi/storage/accounts/postgres/profile_table.go)0
-rw-r--r--userapi/storage/postgres/storage.go (renamed from userapi/storage/accounts/postgres/storage.go)216
-rw-r--r--userapi/storage/postgres/threepid_table.go (renamed from userapi/storage/accounts/postgres/threepid_table.go)0
-rw-r--r--userapi/storage/sqlite3/account_data_table.go (renamed from userapi/storage/accounts/sqlite3/account_data_table.go)0
-rw-r--r--userapi/storage/sqlite3/accounts_table.go (renamed from userapi/storage/accounts/sqlite3/accounts_table.go)0
-rw-r--r--userapi/storage/sqlite3/constraint_wasm.go (renamed from userapi/storage/accounts/sqlite3/constraint_wasm.go)0
-rw-r--r--userapi/storage/sqlite3/deltas/20200929203058_is_active.go (renamed from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go)0
-rw-r--r--userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go (renamed from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go)5
-rw-r--r--userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go (renamed from userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go)0
-rw-r--r--userapi/storage/sqlite3/devices_table.go (renamed from userapi/storage/devices/sqlite3/devices_table.go)3
-rw-r--r--userapi/storage/sqlite3/key_backup_table.go (renamed from userapi/storage/accounts/sqlite3/key_backup_table.go)0
-rw-r--r--userapi/storage/sqlite3/key_backup_version_table.go (renamed from userapi/storage/accounts/sqlite3/key_backup_version_table.go)0
-rw-r--r--userapi/storage/sqlite3/logintoken_table.go (renamed from userapi/storage/devices/sqlite3/logintoken_table.go)3
-rw-r--r--userapi/storage/sqlite3/openid_table.go (renamed from userapi/storage/accounts/sqlite3/openid_table.go)0
-rw-r--r--userapi/storage/sqlite3/profile_table.go (renamed from userapi/storage/accounts/sqlite3/profile_table.go)0
-rw-r--r--userapi/storage/sqlite3/storage.go (renamed from userapi/storage/accounts/sqlite3/storage.go)216
-rw-r--r--userapi/storage/sqlite3/threepid_table.go (renamed from userapi/storage/accounts/sqlite3/threepid_table.go)0
-rw-r--r--userapi/storage/storage.go (renamed from userapi/storage/accounts/storage.go)13
-rw-r--r--userapi/storage/storage_wasm.go (renamed from userapi/storage/accounts/storage_wasm.go)8
-rw-r--r--userapi/userapi.go22
-rw-r--r--userapi/userapi_test.go15
76 files changed, 727 insertions, 899 deletions
diff --git a/appservice/api/query.go b/appservice/api/query.go
index cd74d866..e53ad425 100644
--- a/appservice/api/query.go
+++ b/appservice/api/query.go
@@ -23,7 +23,7 @@ import (
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -85,7 +85,7 @@ func RetrieveUserProfile(
ctx context.Context,
userID string,
asAPI AppServiceQueryAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go
index 211b8d65..acf4406c 100644
--- a/build/gobind-pinecone/monolith.go
+++ b/build/gobind-pinecone/monolith.go
@@ -283,8 +283,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
- cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go
index 3d9ba8aa..8b9c88f2 100644
--- a/build/gobind-yggdrasil/monolith.go
+++ b/build/gobind-yggdrasil/monolith.go
@@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index d678ada9..a65f3b70 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -28,7 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -37,7 +37,7 @@ func AddPublicRoutes(
router *mux.Router,
synapseAdminRouter *mux.Router,
cfg *config.ClientAPI,
- accountsDB accounts.Database,
+ accountsDB userdb.Database,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
eduInputAPI eduServerAPI.EDUServerInputAPI,
diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go
index e89d8ff2..80ac2293 100644
--- a/clientapi/routing/createroom.go
+++ b/clientapi/routing/createroom.go
@@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
@@ -137,7 +137,7 @@ type fledglingEvent struct {
func CreateRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI,
- accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
+ accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
@@ -151,7 +151,7 @@ func CreateRoom(
func createRoom(
req *http.Request, device *api.Device,
cfg *config.ClientAPI, roomID string,
- accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
+ accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go
index 578aaec5..d30a87a5 100644
--- a/clientapi/routing/joinroom.go
+++ b/clientapi/routing/joinroom.go
@@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// Prepare to ask the roomserver to perform the room join.
diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go
index 7b9d8acd..7ecab9d4 100644
--- a/clientapi/routing/key_crosssigning.go
+++ b/clientapi/routing/key_crosssigning.go
@@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@@ -36,7 +36,7 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
- accountDB accounts.Database, cfg *config.ClientAPI,
+ accountDB userdb.Database, cfg *config.ClientAPI,
) util.JSONResponse {
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go
index b48b9e93..ec5c998b 100644
--- a/clientapi/routing/login.go
+++ b/clientapi/routing/login.go
@@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -54,7 +54,7 @@ func passwordLogin() flows {
// Login implements GET and POST /login
func Login(
- req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
+ req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
cfg *config.ClientAPI,
) util.JSONResponse {
if req.Method == http.MethodGet {
diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go
index 58f18760..11223924 100644
--- a/clientapi/routing/membership.go
+++ b/clientapi/routing/membership.go
@@ -30,7 +30,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -39,7 +39,7 @@ import (
var errMissingUserID = errors.New("'user_id' must be supplied")
func SendBan(
- req *http.Request, accountDB accounts.Database, device *userapi.Device,
+ req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@@ -81,7 +81,7 @@ func SendBan(
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
}
-func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device,
+func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
roomVer gomatrixserverlib.RoomVersion,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
@@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
}
func SendKick(
- req *http.Request, accountDB accounts.Database, device *userapi.Device,
+ req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@@ -165,7 +165,7 @@ func SendKick(
}
func SendUnban(
- req *http.Request, accountDB accounts.Database, device *userapi.Device,
+ req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@@ -200,7 +200,7 @@ func SendUnban(
}
func SendInvite(
- req *http.Request, accountDB accounts.Database, device *userapi.Device,
+ req *http.Request, accountDB userdb.Database, device *userapi.Device,
roomID string, cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
@@ -271,7 +271,7 @@ func SendInvite(
func buildMembershipEvent(
ctx context.Context,
- targetUserID, reason string, accountDB accounts.Database,
+ targetUserID, reason string, accountDB userdb.Database,
device *userapi.Device,
membership, roomID string, isDirect bool,
cfg *config.ClientAPI, evTime time.Time,
@@ -312,7 +312,7 @@ func loadProfile(
ctx context.Context,
userID string,
cfg *config.ClientAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
@@ -366,7 +366,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest,
cfg *config.ClientAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
roomID string,
evTime time.Time,
) (inviteStored bool, errRes *util.JSONResponse) {
diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go
index b2442443..49951019 100644
--- a/clientapi/routing/password.go
+++ b/clientapi/routing/password.go
@@ -9,7 +9,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -29,7 +29,7 @@ type newPasswordAuth struct {
func Password(
req *http.Request,
userAPI api.UserInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
diff --git a/clientapi/routing/peekroom.go b/clientapi/routing/peekroom.go
index 26aa64ce..8f89e97f 100644
--- a/clientapi/routing/peekroom.go
+++ b/clientapi/routing/peekroom.go
@@ -19,7 +19,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
roomIDOrAlias string,
) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
@@ -82,7 +82,7 @@ func UnpeekRoomByID(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
roomID string,
) util.JSONResponse {
unpeekReq := roomserverAPI.PerformUnpeekRequest{
diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go
index 017facd2..717cbda7 100644
--- a/clientapi/routing/profile.go
+++ b/clientapi/routing/profile.go
@@ -27,7 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
@@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID}
func GetProfile(
- req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
+ req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@@ -65,7 +65,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL(
- req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
+ req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@@ -92,7 +92,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL(
- req *http.Request, accountDB accounts.Database,
+ req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@@ -182,7 +182,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName(
- req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
+ req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@@ -209,7 +209,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
- req *http.Request, accountDB accounts.Database,
+ req *http.Request, accountDB userdb.Database,
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
if userID != device.UserID {
@@ -302,7 +302,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically
// eventutil.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
- ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI,
+ ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go
index f73cc662..d00d9886 100644
--- a/clientapi/routing/register.go
+++ b/clientapi/routing/register.go
@@ -44,7 +44,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/userutil"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
)
var (
@@ -448,7 +448,7 @@ func validateApplicationService(
func Register(
req *http.Request,
userAPI userapi.UserInternalAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
cfg *config.ClientAPI,
) util.JSONResponse {
var r registerRequest
@@ -899,7 +899,7 @@ type availableResponse struct {
func RegisterAvailable(
req *http.Request,
cfg *config.ClientAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
) util.JSONResponse {
username := req.URL.Query().Get("username")
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index da2ccf2f..63dcaa41 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -34,7 +34,7 @@ import (
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -51,7 +51,7 @@ func Setup(
eduAPI eduServerAPI.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
- accountDB accounts.Database,
+ accountDB userdb.Database,
userAPI userapi.UserInternalAPI,
federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer,
diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go
index 3abf3db2..fd214b34 100644
--- a/clientapi/routing/sendtyping.go
+++ b/clientapi/routing/sendtyping.go
@@ -20,7 +20,7 @@ import (
"github.com/matrix-org/dendrite/eduserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/util"
)
@@ -33,7 +33,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer
func SendTyping(
req *http.Request, device *userapi.Device, roomID string,
- userID string, accountDB accounts.Database,
+ userID string, accountDB userdb.Database,
eduAPI api.EDUServerInputAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) util.JSONResponse {
diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go
index f4d23379..d89b6295 100644
--- a/clientapi/routing/threepid.go
+++ b/clientapi/routing/threepid.go
@@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/threepid"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -40,7 +40,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements:
// POST /account/3pid/email/requestToken
// POST /register/email/requestToken
-func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse {
+func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
@@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
Code: http.StatusBadRequest,
JSON: jsonerror.MatrixError{
ErrCode: "M_THREEPID_IN_USE",
- Err: accounts.Err3PIDInUse.Error(),
+ Err: userdb.Err3PIDInUse.Error(),
},
}
}
@@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
// CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation(
- req *http.Request, accountDB accounts.Database, device *api.Device,
+ req *http.Request, accountDB userdb.Database, device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest
@@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs(
- req *http.Request, accountDB accounts.Database, device *api.Device,
+ req *http.Request, accountDB userdb.Database, device *api.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@@ -170,7 +170,7 @@ func GetAssociated3PIDs(
}
// Forget3PID implements POST /account/3pid/delete
-func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
+func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go
index db62ce06..9d9a2ba7 100644
--- a/clientapi/threepid/invites.go
+++ b/clientapi/threepid/invites.go
@@ -29,7 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite(
ctx context.Context,
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
- rsAPI api.RoomserverInternalAPI, db accounts.Database,
+ rsAPI api.RoomserverInternalAPI, db userdb.Database,
roomID string,
evTime time.Time,
) (inviteStoredOnIDServer bool, err error) {
@@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
- db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
+ db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil {
@@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
- db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
+ db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name
diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go
index d9202eb0..3003896c 100644
--- a/cmd/create-account/main.go
+++ b/cmd/create-account/main.go
@@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
)
const usage = `Usage: %s
@@ -77,9 +77,14 @@ func main() {
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
- accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
- ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
- }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
+ accountDB, err := userdb.NewDatabase(
+ &config.DatabaseOptions{
+ ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
+ },
+ cfg.Global.ServerName, bcrypt.DefaultCost,
+ cfg.UserAPI.OpenIDTokenLifetimeMS,
+ api.DefaultLoginTokenLifetime,
+ )
if err != nil {
logrus.Fatalln("Failed to connect to the database:", err.Error())
}
diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go
index 7cbd0b6d..78536901 100644
--- a/cmd/dendrite-demo-libp2p/main.go
+++ b/cmd/dendrite-demo-libp2p/main.go
@@ -126,7 +126,6 @@ func main() {
cfg.FederationAPI.FederationMaxRetries = 6
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go
index a897dcd1..5810a7f1 100644
--- a/cmd/dendrite-demo-pinecone/main.go
+++ b/cmd/dendrite-demo-pinecone/main.go
@@ -160,7 +160,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go
index 52e69ee5..49e096bd 100644
--- a/cmd/dendrite-demo-yggdrasil/main.go
+++ b/cmd/dendrite-demo-yggdrasil/main.go
@@ -79,7 +79,6 @@ func main() {
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go
index 62eea78f..664f644f 100644
--- a/cmd/dendritejs-pinecone/main.go
+++ b/cmd/dendritejs-pinecone/main.go
@@ -164,7 +164,6 @@ func startup() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
- cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go
index 59de07cd..0ea41b4c 100644
--- a/cmd/dendritejs/main.go
+++ b/cmd/dendritejs/main.go
@@ -167,7 +167,6 @@ func main() {
cfg.Defaults(true)
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
- cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go
index f87665fb..ba5a87a7 100644
--- a/cmd/generate-config/main.go
+++ b/cmd/generate-config/main.go
@@ -32,7 +32,6 @@ func main() {
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
}
cfg.Global.TrustedIDServers = []string{
"matrix.org",
diff --git a/cmd/goose/main.go b/cmd/goose/main.go
index 8ed5cbd9..31a5b005 100644
--- a/cmd/goose/main.go
+++ b/cmd/goose/main.go
@@ -8,12 +8,11 @@ import (
"log"
"os"
- pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
- slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
- pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
- sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
"github.com/pressly/goose"
+ pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@@ -26,8 +25,7 @@ const (
RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi"
- UserAPIAccounts = "userapi_accounts"
- UserAPIDevices = "userapi_devices"
+ UserAPI = "userapi"
)
var (
@@ -35,7 +33,7 @@ var (
flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name")
knownDBs = []string{
- AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
+ AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
}
)
@@ -143,18 +141,14 @@ Commands:
func loadSQLiteDeltas(component string) {
switch component {
- case UserAPIAccounts:
- slaccounts.LoadFromGoose()
- case UserAPIDevices:
- sldevices.LoadFromGoose()
+ case UserAPI:
+ slusers.LoadFromGoose()
}
}
func loadPostgresDeltas(component string) {
switch component {
- case UserAPIAccounts:
- pgaccounts.LoadFromGoose()
- case UserAPIDevices:
- pgdevices.LoadFromGoose()
+ case UserAPI:
+ pgusers.LoadFromGoose()
}
}
diff --git a/internal/test/config.go b/internal/test/config.go
index 4fb6a946..0372fb9c 100644
--- a/internal/test/config.go
+++ b/internal/test/config.go
@@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
- cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
cfg.EDUServer.InternalAPI.Listen = assignAddress()
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index ffbcac94..1c6b0677 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
- msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
+ msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
- deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error {
- keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@@ -554,10 +554,60 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ // get a list of devices from the user API that actually exist, as
+ // we won't store keys for devices that don't exist
+ uapidevices := &userapi.QueryDevicesResponse{}
+ if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return
+ }
+ if !uapidevices.UserExists {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("user %q does not exist", req.UserID),
+ }
+ return
+ }
+ existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
+ for _, key := range uapidevices.Devices {
+ existingDeviceMap[key.ID] = struct{}{}
+ }
+
+ // Get all of the user existing device keys so we can check for changes.
+ existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
+ }
+ return
+ }
+
+ // Work out whether we have device keys in the keyserver for devices that
+ // no longer exist in the user API. This is mostly an exercise to ensure
+ // that we keep some integrity between the two.
+ var toClean []gomatrixserverlib.KeyID
+ for _, k := range existingKeys {
+ if _, ok := existingDeviceMap[k.DeviceID]; !ok {
+ toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
+ }
+ }
+
+ if len(toClean) > 0 {
+ if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()),
+ }
+ return
+ }
+ logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean))
+ }
+
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
- _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
+ var serverName gomatrixserverlib.ServerName
+ _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
@@ -568,6 +618,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
+ // check that the device in question actually exists in the user
+ // API before we try and store a key for it
+ if _, ok := existingDeviceMap[key.DeviceID]; !ok {
+ continue
+ }
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@@ -583,29 +638,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
})
}
- // get existing device keys so we can check for changes
- existingKeys := make([]api.DeviceMessage, len(keysToStore))
- for i := range keysToStore {
- existingKeys[i] = api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: keysToStore[i].UserID,
- DeviceID: keysToStore[i].DeviceID,
- },
- }
- }
- if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
- }
- return
- }
if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore)
}
// store the device keys and emit changes
- err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
+ err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 0110860e..4dffe695 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index 5ae0da96..628301cf 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- countStreamIDsForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
+ return nil, err
+ }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
- rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 5914d28e..deee76eb 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
})
}
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
- return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
+func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index fa1c930d..b461424c 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
+ return nil, err
+ }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
- rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index c4c99d8c..4d513724 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
// Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
+ msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index e44757e1..ff70a236 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
diff --git a/setup/base/base.go b/setup/base/base.go
index 819fe1ad..e3997754 100644
--- a/setup/base/base.go
+++ b/setup/base/base.go
@@ -38,7 +38,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/process"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/gorilla/mux"
@@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
-func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
- db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
+func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
+ db, err := userdb.NewDatabase(
+ &b.Cfg.UserAPI.AccountDatabase,
+ b.Cfg.Global.ServerName,
+ b.Cfg.UserAPI.BCryptCost,
+ b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
+ userapi.DefaultLoginTokenLifetime,
+ )
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}
diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go
index b2cde2e9..1cb5eba1 100644
--- a/setup/config/config_userapi.go
+++ b/setup/config/config_userapi.go
@@ -16,9 +16,6 @@ type UserAPI struct {
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
- // The Device database stores session information for the devices of logged
- // in local users. It is accessed by the UserAPI.
- DeviceDatabase DatabaseOptions `yaml:"device_database"`
}
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
@@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
c.InternalAPI.Listen = "http://localhost:7781"
c.InternalAPI.Connect = "http://localhost:7781"
c.AccountDatabase.Defaults(10)
- c.DeviceDatabase.Defaults(10)
if generate {
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
- c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
}
c.BCryptCost = bcrypt.DefaultCost
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
@@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
- checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
}
diff --git a/setup/monolith.go b/setup/monolith.go
index e6c95522..61125e4a 100644
--- a/setup/monolith.go
+++ b/setup/monolith.go
@@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -38,7 +38,7 @@ import (
// all components of Dendrite, for use in monolith mode.
type Monolith struct {
Config *config.Dendrite
- AccountDB accounts.Database
+ AccountDB userdb.Database
KeyRing *gomatrixserverlib.KeyRing
Client *gomatrixserverlib.Client
FedClient *gomatrixserverlib.FederationClient
diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go
index f3aa037e..e2207bb5 100644
--- a/userapi/api/api_logintoken.go
+++ b/userapi/api/api_logintoken.go
@@ -19,6 +19,13 @@ import (
"time"
)
+// DefaultLoginTokenLifetime determines how old a valid token may be.
+//
+// NOTSPEC: The current spec says "SHOULD be limited to around five
+// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
+// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
+const DefaultLoginTokenLifetime = 2 * time.Minute
+
type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index f96d4804..f54cc613 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
- AccountDB accounts.Database
- DeviceDB devices.Database
+ DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
+ return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
- acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
+ acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
+ if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
- if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
- dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
+ dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
+ devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
- err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
+ err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
- if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
+ if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
- dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
- err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
- prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
+ prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
- profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
+ profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
- devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
+ devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
- devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
+ devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
+ res.UserExists = true
res.Devices = devs
return nil
}
@@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
- data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
- global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
+ global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
- device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
+ device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
- acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
+ acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
@@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
- account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
+ account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
- err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
+ err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
- exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
+ exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
- openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
+ openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
- exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
+ exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
- version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
+ version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
- err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
+ err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
- version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
+ version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
- count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
+ count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
- version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
+ version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
- res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
+ res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
- result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
+ result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return
diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go
index 86ffc58f..f1bf391e 100644
--- a/userapi/internal/api_logintoken.go
+++ b/userapi/internal/api_logintoken.go
@@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
- tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
+ tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
@@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
- return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
+ return a.DB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
- tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
+ tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
@@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
- if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
+ if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil
diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go
deleted file mode 100644
index 8ff91cf1..00000000
--- a/userapi/storage/devices/interface.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package devices
-
-import (
- "context"
-
- "github.com/matrix-org/dendrite/userapi/api"
-)
-
-type Database interface {
- GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
- GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
- GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
- GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
- // CreateDevice makes a new device associated with the given user ID localpart.
- // If there is already a device with the same device ID for this user, that access token will be revoked
- // and replaced with the given accessToken. If the given accessToken is already in use for another device,
- // an error will be returned.
- // If no device ID is given one is generated.
- // Returns the device on success.
- CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
- UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
- UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
- RemoveDevice(ctx context.Context, deviceID, localpart string) error
- RemoveDevices(ctx context.Context, localpart string, devices []string) error
- // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
- RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
-
- // CreateLoginToken generates a token, stores and returns it. The lifetime is
- // determined by the loginTokenLifetime given to the Database constructor.
- CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
-
- // RemoveLoginToken removes the named token (and may clean up other expired tokens).
- RemoveLoginToken(ctx context.Context, token string) error
-
- // GetLoginTokenDataByToken returns the data associated with the given token.
- // May return sql.ErrNoRows.
- GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
-}
diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go
deleted file mode 100644
index fd9d513f..00000000
--- a/userapi/storage/devices/postgres/storage.go
+++ /dev/null
@@ -1,270 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
- "time"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-const (
- // The length of generated device IDs
- deviceIDByteLength = 6
- loginTokenByteLength = 32
-)
-
-// Database represents a device database.
-type Database struct {
- db *sql.DB
- devices devicesStatements
- loginTokens loginTokenStatements
- loginTokenLifetime time.Duration
-}
-
-// NewDatabase creates a new device database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
- db, err := sqlutil.Open(dbProperties)
- if err != nil {
- return nil, err
- }
- var d devicesStatements
- var lt loginTokenStatements
-
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- if err = d.execSchema(db); err != nil {
- return nil, err
- }
- if err = lt.execSchema(db); err != nil {
- return nil, err
- }
-
- m := sqlutil.NewMigrations()
- deltas.LoadLastSeenTSIP(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
-
- if err = d.prepare(db, serverName); err != nil {
- return nil, err
- }
- if err = lt.prepare(db); err != nil {
- return nil, err
- }
-
- return &Database{db, d, lt, loginTokenLifetime}, nil
-}
-
-// GetDeviceByAccessToken returns the device matching the given access token.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByAccessToken(
- ctx context.Context, token string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByToken(ctx, token)
-}
-
-// GetDeviceByID returns the device matching the given ID.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByID(ctx, localpart, deviceID)
-}
-
-// GetDevicesByLocalpart returns the devices matching the given localpart.
-func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
-) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
-}
-
-func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
- return d.devices.selectDevicesByID(ctx, deviceIDs)
-}
-
-// CreateDevice makes a new device associated with the given user ID localpart.
-// If there is already a device with the same device ID for this user, that access token will be revoked
-// and replaced with the given accessToken. If the given accessToken is already in use for another device,
-// an error will be returned.
-// If no device ID is given one is generated.
-// Returns the device on success.
-func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string, ipAddr, userAgent string,
-) (dev *api.Device, returnErr error) {
- if deviceID != nil {
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
- return err
- }
-
- dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- } else {
- // We generate device IDs in a loop in case its already taken.
- // We cap this at going round 5 times to ensure we don't spin forever
- var newDeviceID string
- for i := 1; i <= 5; i++ {
- newDeviceID, returnErr = generateDeviceID()
- if returnErr != nil {
- return
- }
-
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- if returnErr == nil {
- return
- }
- }
- }
- return
-}
-
-// generateDeviceID creates a new device id. Returns an error if failed to generate
-// random bytes.
-func generateDeviceID() (string, error) {
- b := make([]byte, deviceIDByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- // url-safe no padding
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// UpdateDevice updates the given device with the display name.
-// Returns SQL error if there are problems and nil on success.
-func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
- })
-}
-
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveDevices revokes one or more devices by deleting the entry in the database
-// matching with the given device IDs and user ID localpart.
-// If the devices don't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveAllDevices revokes devices by deleting the entry in the
-// database matching the given user ID localpart.
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart, exceptDeviceID string,
-) (devices []api.Device, err error) {
- err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
- if err != nil {
- return err
- }
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
- return
-}
-
-// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
-func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
- })
-}
-
-// CreateLoginToken generates a token, stores and returns it. The lifetime is
-// determined by the loginTokenLifetime given to the Database constructor.
-func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
- tok, err := generateLoginToken()
- if err != nil {
- return nil, err
- }
- meta := &api.LoginTokenMetadata{
- Token: tok,
- Expiration: time.Now().Add(d.loginTokenLifetime),
- }
-
- err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.loginTokens.insert(ctx, txn, meta, data)
- })
- if err != nil {
- return nil, err
- }
-
- return meta, nil
-}
-
-func generateLoginToken() (string, error) {
- b := make([]byte, loginTokenByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// RemoveLoginToken removes the named token (and may clean up other expired tokens).
-func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.loginTokens.deleteByToken(ctx, txn, token)
- })
-}
-
-// GetLoginTokenDataByToken returns the data associated with the given token.
-// May return sql.ErrNoRows.
-func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
- return d.loginTokens.selectByToken(ctx, token)
-}
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
deleted file mode 100644
index 6e90413b..00000000
--- a/userapi/storage/devices/sqlite3/storage.go
+++ /dev/null
@@ -1,271 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
- "time"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-const (
- // The length of generated device IDs
- deviceIDByteLength = 6
-
- loginTokenByteLength = 32
-)
-
-// Database represents a device database.
-type Database struct {
- db *sql.DB
- writer sqlutil.Writer
- devices devicesStatements
- loginTokens loginTokenStatements
- loginTokenLifetime time.Duration
-}
-
-// NewDatabase creates a new device database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
- db, err := sqlutil.Open(dbProperties)
- if err != nil {
- return nil, err
- }
- writer := sqlutil.NewExclusiveWriter()
- var d devicesStatements
- var lt loginTokenStatements
-
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- if err = d.execSchema(db); err != nil {
- return nil, err
- }
- if err = lt.execSchema(db); err != nil {
- return nil, err
- }
-
- m := sqlutil.NewMigrations()
- deltas.LoadLastSeenTSIP(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
- if err = d.prepare(db, writer, serverName); err != nil {
- return nil, err
- }
- if err = lt.prepare(db); err != nil {
- return nil, err
- }
- return &Database{db, writer, d, lt, loginTokenLifetime}, nil
-}
-
-// GetDeviceByAccessToken returns the device matching the given access token.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByAccessToken(
- ctx context.Context, token string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByToken(ctx, token)
-}
-
-// GetDeviceByID returns the device matching the given ID.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByID(ctx, localpart, deviceID)
-}
-
-// GetDevicesByLocalpart returns the devices matching the given localpart.
-func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
-) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
-}
-
-func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
- return d.devices.selectDevicesByID(ctx, deviceIDs)
-}
-
-// CreateDevice makes a new device associated with the given user ID localpart.
-// If there is already a device with the same device ID for this user, that access token will be revoked
-// and replaced with the given accessToken. If the given accessToken is already in use for another device,
-// an error will be returned.
-// If no device ID is given one is generated.
-// Returns the device on success.
-func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string, ipAddr, userAgent string,
-) (dev *api.Device, returnErr error) {
- if deviceID != nil {
- returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
- return err
- }
-
- dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- } else {
- // We generate device IDs in a loop in case its already taken.
- // We cap this at going round 5 times to ensure we don't spin forever
- var newDeviceID string
- for i := 1; i <= 5; i++ {
- newDeviceID, returnErr = generateDeviceID()
- if returnErr != nil {
- return
- }
-
- returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- var err error
- dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- if returnErr == nil {
- return
- }
- }
- }
- return
-}
-
-// generateDeviceID creates a new device id. Returns an error if failed to generate
-// random bytes.
-func generateDeviceID() (string, error) {
- b := make([]byte, deviceIDByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- // url-safe no padding
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// UpdateDevice updates the given device with the display name.
-// Returns SQL error if there are problems and nil on success.
-func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
- })
-}
-
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveDevices revokes one or more devices by deleting the entry in the database
-// matching with the given device IDs and user ID localpart.
-// If the devices don't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveAllDevices revokes devices by deleting the entry in the
-// database matching the given user ID localpart.
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart, exceptDeviceID string,
-) (devices []api.Device, err error) {
- err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
- if err != nil {
- return err
- }
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
- return
-}
-
-// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
-func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
- })
-}
-
-// CreateLoginToken generates a token, stores and returns it. The lifetime is
-// determined by the loginTokenLifetime given to the Database constructor.
-func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
- tok, err := generateLoginToken()
- if err != nil {
- return nil, err
- }
- meta := &api.LoginTokenMetadata{
- Token: tok,
- Expiration: time.Now().Add(d.loginTokenLifetime),
- }
-
- err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.loginTokens.insert(ctx, txn, meta, data)
- })
- if err != nil {
- return nil, err
- }
-
- return meta, nil
-}
-
-func generateLoginToken() (string, error) {
- b := make([]byte, loginTokenByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// RemoveLoginToken removes the named token (and may clean up other expired tokens).
-func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.loginTokens.deleteByToken(ctx, txn, token)
- })
-}
-
-// GetLoginTokenDataByToken returns the data associated with the given token.
-// May return sql.ErrNoRows.
-func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
- return d.loginTokens.selectByToken(ctx, token)
-}
diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go
deleted file mode 100644
index 15cf8150..00000000
--- a/userapi/storage/devices/storage.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build !wasm
-// +build !wasm
-
-package devices
-
-import (
- "fmt"
- "time"
-
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
-// and sets postgres connection parameters. loginTokenLifetime determines how long a
-// login token from CreateLoginToken is valid.
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go
deleted file mode 100644
index 3de7880b..00000000
--- a/userapi/storage/devices/storage_wasm.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package devices
-
-import (
- "fmt"
- "time"
-
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-func NewDatabase(
- dbProperties *config.DatabaseOptions,
- serverName gomatrixserverlib.ServerName,
- loginTokenLifetime time.Duration,
-) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- case dbProperties.ConnectionString.IsPostgres():
- return nil, fmt.Errorf("can't use Postgres implementation")
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go
index a2185774..a131dac4 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/interface.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package storage
import (
"context"
@@ -60,6 +60,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
+
+ GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
+ GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
+ GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
+ GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
+ // CreateDevice makes a new device associated with the given user ID localpart.
+ // If there is already a device with the same device ID for this user, that access token will be revoked
+ // and replaced with the given accessToken. If the given accessToken is already in use for another device,
+ // an error will be returned.
+ // If no device ID is given one is generated.
+ // Returns the device on success.
+ CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
+ UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
+ UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
+ RemoveDevice(ctx context.Context, deviceID, localpart string) error
+ RemoveDevices(ctx context.Context, localpart string, devices []string) error
+ // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
+ RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
+
+ // CreateLoginToken generates a token, stores and returns it. The lifetime is
+ // determined by the loginTokenLifetime given to the Database constructor.
+ CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
+
+ // RemoveLoginToken removes the named token (and may clean up other expired tokens).
+ RemoveLoginToken(ctx context.Context, token string) error
+
+ // GetLoginTokenDataByToken returns the data associated with the given token.
+ // May return sql.ErrNoRows.
+ GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 8ba890e7..8ba890e7 100644
--- a/userapi/storage/accounts/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index 9e3e456a..9e3e456a 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go
index 32d3235b..32d3235b 100644
--- a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go
+++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go
diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
index 290f854c..1bbb0a9d 100644
--- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
@@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
diff --git a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
index 2fae00cb..2fae00cb 100644
--- a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go
+++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7de9f5f9..64cc0b71 100644
--- a/userapi/storage/devices/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -117,6 +117,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+ if err = s.execSchema(db); err != nil {
+ return
+ }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go
index c1402d4d..c1402d4d 100644
--- a/userapi/storage/accounts/postgres/key_backup_table.go
+++ b/userapi/storage/postgres/key_backup_table.go
diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go
index d73447b4..d73447b4 100644
--- a/userapi/storage/accounts/postgres/key_backup_version_table.go
+++ b/userapi/storage/postgres/key_backup_version_table.go
diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go
index f601fc7d..508a6898 100644
--- a/userapi/storage/devices/postgres/logintoken_table.go
+++ b/userapi/storage/postgres/logintoken_table.go
@@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
+ if err := s.execSchema(db); err != nil {
+ return err
+ }
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go
index 190d141b..190d141b 100644
--- a/userapi/storage/accounts/postgres/openid_table.go
+++ b/userapi/storage/postgres/openid_table.go
diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index 9313864b..9313864b 100644
--- a/userapi/storage/accounts/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/postgres/storage.go
index d31efd25..73419279 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -16,7 +16,9 @@ package postgres
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -30,7 +32,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
@@ -47,14 +49,23 @@ type Database struct {
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
+ devices devicesStatements
+ loginTokens loginTokenStatements
+ loginTokenLifetime time.Duration
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
+const (
+ // The length of generated device IDs
+ deviceIDByteLength = 6
+ loginTokenByteLength = 32
+)
+
// NewDatabase creates a new accounts and profiles database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@@ -63,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
+ loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
@@ -74,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
+ //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
@@ -103,6 +116,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
+ if err = d.devices.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ if err = d.loginTokens.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -515,3 +534,196 @@ func (d *Database) UpsertBackupKeys(
})
return
}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]api.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
+}
+
+func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ return d.devices.selectDevicesByID(ctx, deviceIDs)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string, ipAddr, userAgent string,
+) (dev *api.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart, exceptDeviceID string,
+) (devices []api.Device, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
+ if err != nil {
+ return err
+ }
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
+func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
+ })
+}
+
+// CreateLoginToken generates a token, stores and returns it. The lifetime is
+// determined by the loginTokenLifetime given to the Database constructor.
+func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
+ tok, err := generateLoginToken()
+ if err != nil {
+ return nil, err
+ }
+ meta := &api.LoginTokenMetadata{
+ Token: tok,
+ Expiration: time.Now().Add(d.loginTokenLifetime),
+ }
+
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.loginTokens.insert(ctx, txn, meta, data)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return meta, nil
+}
+
+func generateLoginToken() (string, error) {
+ b := make([]byte, loginTokenByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// RemoveLoginToken removes the named token (and may clean up other expired tokens).
+func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.loginTokens.deleteByToken(ctx, txn, token)
+ })
+}
+
+// GetLoginTokenDataByToken returns the data associated with the given token.
+// May return sql.ErrNoRows.
+func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
+ return d.loginTokens.selectByToken(ctx, token)
+}
diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go
index 9280fc87..9280fc87 100644
--- a/userapi/storage/accounts/postgres/threepid_table.go
+++ b/userapi/storage/postgres/threepid_table.go
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go
index 871f996e..871f996e 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/sqlite3/account_data_table.go
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index 5a918e03..5a918e03 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go
index 6c4ee762..6c4ee762 100644
--- a/userapi/storage/accounts/sqlite3/constraint_wasm.go
+++ b/userapi/storage/sqlite3/constraint_wasm.go
diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
index c69614e8..c69614e8 100644
--- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go
+++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
index 26209826..ebf90800 100644
--- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
@@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
index 9b058ded..9b058ded 100644
--- a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go
+++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 955d8ac7..119ecdf9 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -106,6 +106,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
+ if err = s.execSchema(db); err != nil {
+ return
+ }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go
index 837d38cf..837d38cf 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_table.go
+++ b/userapi/storage/sqlite3/key_backup_table.go
diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go
index 4211ed0f..4211ed0f 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go
+++ b/userapi/storage/sqlite3/key_backup_version_table.go
diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go
index 75ef272f..52322b46 100644
--- a/userapi/storage/devices/sqlite3/logintoken_table.go
+++ b/userapi/storage/sqlite3/logintoken_table.go
@@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
+ if err := s.execSchema(db); err != nil {
+ return err
+ }
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go
index 98c0488b..98c0488b 100644
--- a/userapi/storage/accounts/sqlite3/openid_table.go
+++ b/userapi/storage/sqlite3/openid_table.go
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index a92e9566..a92e9566 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 0bab16ca..56ec1b6a 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -16,7 +16,9 @@ package sqlite3
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -31,7 +33,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// Database represents an account database
@@ -47,6 +49,9 @@ type Database struct {
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
+ devices devicesStatements
+ loginTokens loginTokenStatements
+ loginTokenLifetime time.Duration
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
@@ -57,8 +62,14 @@ type Database struct {
threepidsMu sync.Mutex
}
+const (
+ // The length of generated device IDs
+ deviceIDByteLength = 6
+ loginTokenByteLength = 32
+)
+
// NewDatabase creates a new accounts and profiles database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@@ -67,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
+ loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
@@ -78,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
+ //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
@@ -108,6 +121,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
+ if err = d.devices.prepare(db, d.writer, serverName); err != nil {
+ return nil, err
+ }
+ if err = d.loginTokens.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -547,3 +566,196 @@ func (d *Database) UpsertBackupKeys(
})
return
}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]api.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
+}
+
+func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ return d.devices.selectDevicesByID(ctx, deviceIDs)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string, ipAddr, userAgent string,
+) (dev *api.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart, exceptDeviceID string,
+) (devices []api.Device, err error) {
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
+ if err != nil {
+ return err
+ }
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
+func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
+ })
+}
+
+// CreateLoginToken generates a token, stores and returns it. The lifetime is
+// determined by the loginTokenLifetime given to the Database constructor.
+func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
+ tok, err := generateLoginToken()
+ if err != nil {
+ return nil, err
+ }
+ meta := &api.LoginTokenMetadata{
+ Token: tok,
+ Expiration: time.Now().Add(d.loginTokenLifetime),
+ }
+
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.loginTokens.insert(ctx, txn, meta, data)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return meta, nil
+}
+
+func generateLoginToken() (string, error) {
+ b := make([]byte, loginTokenByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// RemoveLoginToken removes the named token (and may clean up other expired tokens).
+func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.loginTokens.deleteByToken(ctx, txn, token)
+ })
+}
+
+// GetLoginTokenDataByToken returns the data associated with the given token.
+// May return sql.ErrNoRows.
+func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
+ return d.loginTokens.selectByToken(ctx, token)
+}
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go
index 9dc0e2d2..9dc0e2d2 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/sqlite3/threepid_table.go
diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go
index f43f7efd..4711439a 100644
--- a/userapi/storage/accounts/storage.go
+++ b/userapi/storage/storage.go
@@ -15,26 +15,27 @@
//go:build !wasm
// +build !wasm
-package accounts
+package storage
import (
"fmt"
+ "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go
index 11a88a20..701dcd83 100644
--- a/userapi/storage/accounts/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package storage
import (
"fmt"
+ "time"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -27,10 +28,11 @@ func NewDatabase(
serverName gomatrixserverlib.ServerName,
bcryptCost int,
openIDTokenLifetimeMS int64,
+ loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
diff --git a/userapi/userapi.go b/userapi/userapi.go
index c7e1f667..4a5793ab 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -23,18 +23,10 @@ import (
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
)
-// defaultLoginTokenLifetime determines how old a valid token may be.
-//
-// NOTSPEC: The current spec says "SHOULD be limited to around five
-// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
-// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
-const defaultLoginTokenLifetime = 2 * time.Minute
-
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
@@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
- deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime)
+ db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
- return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI)
+ return newInternalAPI(db, cfg, appServices, keyAPI)
}
func newInternalAPI(
- accountDB accounts.Database,
- deviceDB devices.Database,
+ db storage.Database,
cfg *config.UserAPI,
appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
return &internal.UserInternalAPI{
- AccountDB: accountDB,
- DeviceDB: deviceDB,
+ DB: db,
ServerName: cfg.Matrix.ServerName,
AppServices: appServices,
KeyAPI: keyAPI,
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 141dd96d..4214c07f 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -31,8 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/inthttp"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
const (
@@ -43,23 +42,19 @@ type apiTestOpts struct {
loginTokenLifetime time.Duration
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) {
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) {
if opts.loginTokenLifetime == 0 {
- opts.loginTokenLifetime = defaultLoginTokenLifetime
+ opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
dbopts := &config.DatabaseOptions{
ConnectionString: "file::memory:",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
}
- accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
+ accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime)
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
- deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
- if err != nil {
- t.Fatalf("failed to create device DB: %s", err)
- }
cfg := &config.UserAPI{
Matrix: &config.Global{
@@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a
},
}
- return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB
+ return newInternalAPI(accountDB, cfg, nil, nil), accountDB
}
func TestQueryProfile(t *testing.T) {