aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-04-27 15:05:49 +0200
committerGitHub <noreply@github.com>2022-04-27 15:05:49 +0200
commitf023cdf8c42cc1a4bb850b478dbbf7d901b5e1bd (patch)
tree5698494b5438a721976a1f685dcd29301538a7d7 /userapi/storage
parentd7cc187ec00410b949ffae1625835f8ac9f36c29 (diff)
Add UserAPI storage tests (#2384)
* Add tests for parts of the userapi storage * Add tests for keybackup * Add LoginToken tests * Add OpenID tests * Add profile tests * Add pusher tests * Add ThreePID tests * Add notification tests * Add more device tests, fix numeric localpart query * Fix failing CI * Fix numeric local part query
Diffstat (limited to 'userapi/storage')
-rw-r--r--userapi/storage/interface.go87
-rw-r--r--userapi/storage/postgres/accounts_table.go6
-rw-r--r--userapi/storage/postgres/devices_table.go22
-rw-r--r--userapi/storage/shared/storage.go15
-rw-r--r--userapi/storage/sqlite3/accounts_table.go8
-rw-r--r--userapi/storage/sqlite3/devices_table.go22
-rw-r--r--userapi/storage/storage.go4
-rw-r--r--userapi/storage/storage_test.go539
-rw-r--r--userapi/storage/storage_wasm.go2
9 files changed, 634 insertions, 71 deletions
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index b15470dd..a4562cf1 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -27,18 +27,24 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
- SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
}
-type Database interface {
- Profile
- GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
+type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
+ GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
+ GetNewNumericLocalpart(ctx context.Context) (int64, error)
+ CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
+ GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
+ DeactivateAccount(ctx context.Context, localpart string) (err error)
+ SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
+}
+
+type AccountData interface {
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
@@ -46,26 +52,9 @@ type Database interface {
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
- GetNewNumericLocalpart(ctx context.Context) (int64, error)
- SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
- RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
- GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
- GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
- CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
- GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
- DeactivateAccount(ctx context.Context, localpart string) (err error)
- CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
- GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
-
- // Key backups
- CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
- UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
- DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
- GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
- UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
- GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
- CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
+}
+type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
@@ -79,11 +68,22 @@ type Database interface {
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
- RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
+}
+type KeyBackup interface {
+ CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
+ UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
+ DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
+ GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
+ UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
+ GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
+ CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
+}
+
+type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
@@ -94,19 +94,48 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
+}
+
+type OpenID interface {
+ CreateOpenIDToken(ctx context.Context, token, userID string) (exp int64, err error)
+ GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
+}
+type Pusher interface {
+ UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
+ GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
+ RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
+ RemovePushers(ctx context.Context, appid, pushkey string) error
+}
+
+type ThreePID interface {
+ SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
+ RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
+ GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
+ GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
+}
+
+type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
- SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
+ SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error
+}
- UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
- GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
- RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
- RemovePushers(ctx context.Context, appid, pushkey string) error
+type Database interface {
+ Account
+ AccountData
+ Device
+ KeyBackup
+ LoginToken
+ Notification
+ OpenID
+ Profile
+ Pusher
+ ThreePID
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index 92311d56..f86812f1 100644
--- a/userapi/storage/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
@@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
--- Create sequence for autogenerated numeric usernames
-CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertAccountSQL = "" +
@@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
- "SELECT nextval('numeric_username_seq')"
+ "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
@@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
- return
+ return id + 1, err
}
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7bc5dc69..fe8c54e0 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -93,7 +93,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
+ "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
@@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
+ var dev api.Device
+ var localpart string
+ var lastseents sql.NullInt64
+ var displayName sql.NullString
for rows.Next() {
- var dev api.Device
- var localpart string
- var displayName sql.NullString
- if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
+ if lastseents.Valid {
+ dev.LastSeenTS = lastseents.Int64
+ }
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
@@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
+ var dev api.Device
+ var lastseents sql.NullInt64
+ var id, displayname, ip, useragent sql.NullString
for rows.Next() {
- var dev api.Device
- var lastseents sql.NullInt64
- var id, displayname, ip, useragent sql.NullString
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
if err != nil {
return devices, err
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 72ae96ec..f7212e03 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -577,21 +577,6 @@ func (d *Database) UpdateDevice(
})
}
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index e6c37e58..6c5fe307 100644
--- a/userapi/storage/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
@@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
- "SELECT COUNT(localpart) FROM account_accounts"
+ "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
type accountsStatements struct {
db *sql.DB
@@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
+ AccountType: accountType,
}, nil
}
@@ -177,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
- return
+ if err == sql.ErrNoRows {
+ return 1, nil
+ }
+ return id + 1, err
}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 423640e9..7860bd6a 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -78,7 +78,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
+ "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
@@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
return devices, err
}
+ var dev api.Device
+ var lastseents sql.NullInt64
+ var id, displayname, ip, useragent sql.NullString
for rows.Next() {
- var dev api.Device
- var lastseents sql.NullInt64
- var id, displayname, ip, useragent sql.NullString
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
if err != nil {
return devices, err
@@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
+ var dev api.Device
+ var localpart string
+ var displayName sql.NullString
+ var lastseents sql.NullInt64
for rows.Next() {
- var dev api.Device
- var localpart string
- var displayName sql.NullString
- if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
+ if lastseents.Valid {
+ dev.LastSeenTS = lastseents.Int64
+ }
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go
index f372fe7d..faf1ce75 100644
--- a/userapi/storage/storage.go
+++ b/userapi/storage/storage.go
@@ -28,9 +28,9 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
-// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
+func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
new file mode 100644
index 00000000..e6c7d35f
--- /dev/null
+++ b/userapi/storage/storage_test.go
@@ -0,0 +1,539 @@
+package storage_test
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/crypto/bcrypt"
+)
+
+const loginTokenLifetime = time.Minute
+
+var (
+ openIDLifetimeMS = time.Minute.Milliseconds()
+ ctx = context.Background()
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
+ if err != nil {
+ t.Fatalf("NewUserAPIDatabase returned %s", err)
+ }
+ return db, close
+}
+
+// Tests storing and getting account data
+func Test_AccountData(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+ alice := test.NewUser()
+ localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+
+ room := test.NewRoom(t, alice)
+ events := room.Events()
+
+ contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
+ err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
+ assert.NoError(t, err, "unable to save account data")
+
+ contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
+ err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
+ assert.NoError(t, err, "unable to save account data")
+
+ accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
+ assert.NoError(t, err, "unable to get account data by type")
+ assert.Equal(t, contentRoom, accountData)
+
+ globalData, roomData, err := db.GetAccountData(ctx, localpart)
+ assert.NoError(t, err)
+ assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
+ assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
+ })
+}
+
+// Tests the creation of accounts
+func Test_Accounts(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+ alice := test.NewUser()
+ aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+
+ accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
+ assert.NoError(t, err, "failed to create account")
+ // verify the newly create account is the same as returned by CreateAccount
+ var accGet *api.Account
+ accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
+ assert.NoError(t, err, "failed to get account by password")
+ assert.Equal(t, accAlice, accGet)
+ accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
+ assert.NoError(t, err, "failed to get account by localpart")
+ assert.Equal(t, accAlice, accGet)
+
+ // check account availability
+ available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
+ assert.NoError(t, err, "failed to checkout account availability")
+ assert.Equal(t, false, available)
+
+ available, err = db.CheckAccountAvailability(ctx, "unusedname")
+ assert.NoError(t, err, "failed to checkout account availability")
+ assert.Equal(t, true, available)
+
+ // get guest account numeric aliceLocalpart
+ first, err := db.GetNewNumericLocalpart(ctx)
+ assert.NoError(t, err, "failed to get new numeric localpart")
+ // Create a new account to verify the numeric localpart is updated
+ _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
+ assert.NoError(t, err, "failed to create account")
+ second, err := db.GetNewNumericLocalpart(ctx)
+ assert.NoError(t, err)
+ assert.Greater(t, second, first)
+
+ // update password for alice
+ err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
+ assert.NoError(t, err, "failed to update password")
+ accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
+ assert.NoError(t, err, "failed to get account by new password")
+ assert.Equal(t, accAlice, accGet)
+
+ // deactivate account
+ err = db.DeactivateAccount(ctx, aliceLocalpart)
+ assert.NoError(t, err, "failed to deactivate account")
+ // This should fail now, as the account is deactivated
+ _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
+ assert.Error(t, err, "expected an error, got none")
+
+ _, err = db.GetAccountByLocalpart(ctx, "unusename")
+ assert.Error(t, err, "expected an error for non existent localpart")
+ })
+}
+
+func Test_Devices(t *testing.T) {
+ alice := test.NewUser()
+ localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+ deviceID := util.RandomString(8)
+ accessToken := util.RandomString(16)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
+ assert.NoError(t, err, "unable to create deviceWithoutID")
+
+ gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
+ assert.NoError(t, err, "unable to get device by id")
+ assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
+
+ gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
+ assert.NoError(t, err, "unable to get device by access token")
+ assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
+
+ // create a device without existing device ID
+ accessToken = util.RandomString(16)
+ deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
+ assert.NoError(t, err, "unable to create deviceWithoutID")
+ gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
+ assert.NoError(t, err, "unable to get device by id")
+ assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
+
+ // Get devices
+ devices, err := db.GetDevicesByLocalpart(ctx, localpart)
+ assert.NoError(t, err, "unable to get devices by localpart")
+ assert.Equal(t, 2, len(devices))
+ deviceIDs := make([]string, 0, len(devices))
+ for _, dev := range devices {
+ deviceIDs = append(deviceIDs, dev.ID)
+ }
+
+ devices2, err := db.GetDevicesByID(ctx, deviceIDs)
+ assert.NoError(t, err, "unable to get devices by id")
+ assert.Equal(t, devices, devices2)
+
+ // Update device
+ newName := "new display name"
+ err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
+ assert.NoError(t, err, "unable to update device displayname")
+ err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1")
+ assert.NoError(t, err, "unable to update device last seen")
+
+ deviceWithID.DisplayName = newName
+ deviceWithID.LastSeenIP = "127.0.0.1"
+ deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
+ devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ assert.NoError(t, err, "unable to get device by id")
+ assert.Equal(t, 2, len(devices))
+ assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
+ assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
+ truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
+ assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
+
+ // create one more device and remove the devices step by step
+ newDeviceID := util.RandomString(16)
+ accessToken = util.RandomString(16)
+ _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
+ assert.NoError(t, err, "unable to create new device")
+
+ devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ assert.NoError(t, err, "unable to get device by id")
+ assert.Equal(t, 3, len(devices))
+
+ err = db.RemoveDevices(ctx, localpart, deviceIDs)
+ assert.NoError(t, err, "unable to remove devices")
+ devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ assert.NoError(t, err, "unable to get device by id")
+ assert.Equal(t, 1, len(devices))
+
+ deleted, err := db.RemoveAllDevices(ctx, localpart, "")
+ assert.NoError(t, err, "unable to remove all devices")
+ assert.Equal(t, 1, len(deleted))
+ assert.Equal(t, newDeviceID, deleted[0].ID)
+ })
+}
+
+func Test_KeyBackup(t *testing.T) {
+ alice := test.NewUser()
+ room := test.NewRoom(t, alice)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ wantAuthData := json.RawMessage("my auth data")
+ wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
+ assert.NoError(t, err, "unable to create key backup")
+ // get key backup by version
+ gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
+ assert.NoError(t, err, "unable to get key backup")
+ assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
+ assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
+ assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
+
+ // get any key backup
+ gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
+ assert.NoError(t, err, "unable to get key backup")
+ assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
+ assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
+ assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
+
+ err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
+ assert.NoError(t, err, "unable to update key backup auth data")
+
+ uploads := []api.InternalKeyBackupSession{
+ {
+ KeyBackupSession: api.KeyBackupSession{
+ IsVerified: true,
+ SessionData: wantAuthData,
+ },
+ RoomID: room.ID,
+ SessionID: "1",
+ },
+ {
+ KeyBackupSession: api.KeyBackupSession{},
+ RoomID: room.ID,
+ SessionID: "2",
+ },
+ }
+ count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
+ assert.NoError(t, err, "unable to upsert backup keys")
+ assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
+
+ // do it again to update a key
+ uploads[1].IsVerified = true
+ count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
+ assert.NoError(t, err, "unable to upsert backup keys")
+ assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
+
+ // get backup keys by session id
+ gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
+ assert.NoError(t, err, "unable to get backup keys")
+ assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
+
+ // get backup keys by room id
+ gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
+ assert.NoError(t, err, "unable to get backup keys")
+ assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
+
+ gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
+ assert.NoError(t, err, "unable to get backup keys count")
+ assert.Equal(t, count, gotCount, "unexpected backup count")
+
+ // finally delete a key
+ exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
+ assert.NoError(t, err, "unable to delete key backup")
+ assert.True(t, exists)
+
+ // this key should not exist
+ exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
+ assert.NoError(t, err, "unable to delete key backup")
+ assert.False(t, exists)
+ })
+}
+
+func Test_LoginToken(t *testing.T) {
+ alice := test.NewUser()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ // create a new token
+ wantLoginToken := &api.LoginTokenData{UserID: alice.ID}
+
+ gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken)
+ assert.NoError(t, err, "unable to create login token")
+ assert.NotNil(t, gotMetadata)
+ assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime))
+
+ // get the new token
+ gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
+ assert.NoError(t, err, "unable to get login token")
+ assert.NotNil(t, gotLoginToken)
+ assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token")
+
+ // remove the login token again
+ err = db.RemoveLoginToken(ctx, gotMetadata.Token)
+ assert.NoError(t, err, "unable to remove login token")
+
+ // check if the token was actually deleted
+ _, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
+ assert.Error(t, err, "expected an error, but got none")
+ })
+}
+
+func Test_OpenID(t *testing.T) {
+ alice := test.NewUser()
+ token := util.RandomString(24)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
+ expires, err := db.CreateOpenIDToken(ctx, token, alice.ID)
+ assert.NoError(t, err, "unable to create OpenID token")
+ assert.Equal(t, expiresAtMS, expires)
+
+ attributes, err := db.GetOpenIDTokenAttributes(ctx, token)
+ assert.NoError(t, err, "unable to get OpenID token attributes")
+ assert.Equal(t, alice.ID, attributes.UserID)
+ assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS)
+ })
+}
+
+func Test_Profile(t *testing.T) {
+ alice := test.NewUser()
+ aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ // create account, which also creates a profile
+ _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
+ assert.NoError(t, err, "failed to create account")
+
+ gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get profile by localpart")
+ wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
+ assert.Equal(t, wantProfile, gotProfile)
+
+ // set avatar & displayname
+ wantProfile.DisplayName = "Alice"
+ wantProfile.AvatarURL = "mxc://aliceAvatar"
+ err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ assert.NoError(t, err, "unable to set displayname")
+ err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ assert.NoError(t, err, "unable to set avatar url")
+ // verify profile
+ gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get profile by localpart")
+ assert.Equal(t, wantProfile, gotProfile)
+
+ // search profiles
+ searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
+ assert.NoError(t, err, "unable to search profiles")
+ assert.Equal(t, 1, len(searchRes))
+ assert.Equal(t, *wantProfile, searchRes[0])
+ })
+}
+
+func Test_Pusher(t *testing.T) {
+ alice := test.NewUser()
+ aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+
+ appID := util.RandomString(8)
+ var pushKeys []string
+ var gotPushers []api.Pusher
+ for i := 0; i < 2; i++ {
+ pushKey := util.RandomString(8)
+
+ wantPusher := api.Pusher{
+ PushKey: pushKey,
+ Kind: api.HTTPKind,
+ AppID: appID,
+ AppDisplayName: util.RandomString(8),
+ DeviceDisplayName: util.RandomString(8),
+ ProfileTag: util.RandomString(8),
+ Language: util.RandomString(2),
+ }
+ err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
+ assert.NoError(t, err, "unable to upsert pusher")
+
+ // check it was actually persisted
+ gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get pushers")
+ assert.Equal(t, i+1, len(gotPushers))
+ assert.Equal(t, wantPusher, gotPushers[i])
+ pushKeys = append(pushKeys, pushKey)
+ }
+
+ // remove single pusher
+ err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
+ assert.NoError(t, err, "unable to remove pusher")
+ gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get pushers")
+ assert.Equal(t, 1, len(gotPushers))
+
+ // remove last pusher
+ err = db.RemovePushers(ctx, appID, pushKeys[1])
+ assert.NoError(t, err, "unable to remove pusher")
+ gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get pushers")
+ assert.Equal(t, 0, len(gotPushers))
+ })
+}
+
+func Test_ThreePID(t *testing.T) {
+ alice := test.NewUser()
+ aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+ threePID := util.RandomString(8)
+ medium := util.RandomString(8)
+ err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
+ assert.NoError(t, err, "unable to save threepid association")
+
+ // get the stored threepid
+ gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
+ assert.NoError(t, err, "unable to get localpart for threepid")
+ assert.Equal(t, aliceLocalpart, gotLocalpart)
+
+ threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get threepids for localpart")
+ assert.Equal(t, 1, len(threepids))
+ assert.Equal(t, authtypes.ThreePID{
+ Address: threePID,
+ Medium: medium,
+ }, threepids[0])
+
+ // remove threepid association
+ err = db.RemoveThreePIDAssociation(ctx, threePID, medium)
+ assert.NoError(t, err, "unexpected error")
+
+ // verify it was deleted
+ threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
+ assert.NoError(t, err, "unable to get threepids for localpart")
+ assert.Equal(t, 0, len(threepids))
+ })
+}
+
+func Test_Notification(t *testing.T) {
+ alice := test.NewUser()
+ aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ assert.NoError(t, err)
+ room := test.NewRoom(t, alice)
+ room2 := test.NewRoom(t, alice)
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateDatabase(t, dbType)
+ defer close()
+ // generate some dummy notifications
+ for i := 0; i < 10; i++ {
+ eventID := util.RandomString(16)
+ roomID := room.ID
+ ts := time.Now()
+ if i > 5 {
+ roomID = room2.ID
+ // create some old notifications to test DeleteOldNotifications
+ ts = ts.AddDate(0, -2, 0)
+ }
+ notification := &api.Notification{
+ Actions: []*pushrules.Action{
+ {},
+ },
+ Event: gomatrixserverlib.ClientEvent{
+ Content: gomatrixserverlib.RawJSON("{}"),
+ },
+ Read: false,
+ RoomID: roomID,
+ TS: gomatrixserverlib.AsTimestamp(ts),
+ }
+ err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification)
+ assert.NoError(t, err, "unable to insert notification")
+ }
+
+ // get notifications
+ count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
+ assert.NoError(t, err, "unable to get notification count")
+ assert.Equal(t, int64(10), count)
+ notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
+ assert.NoError(t, err, "unable to get notifications")
+ assert.Equal(t, int64(10), count)
+ assert.Equal(t, 10, len(notifs))
+ // ... for a specific room
+ total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ assert.NoError(t, err, "unable to get notifications for room")
+ assert.Equal(t, int64(4), total)
+
+ // mark notification as read
+ affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
+ assert.NoError(t, err, "unable to set notifications read")
+ assert.True(t, affected)
+
+ // this should delete 2 notifications
+ affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
+ assert.NoError(t, err, "unable to set notifications read")
+ assert.True(t, affected)
+
+ total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ assert.NoError(t, err, "unable to get notifications for room")
+ assert.Equal(t, int64(2), total)
+
+ // delete old notifications
+ err = db.DeleteOldNotifications(ctx)
+ assert.NoError(t, err)
+
+ // this should now return 0 notifications
+ total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ assert.NoError(t, err, "unable to get notifications for room")
+ assert.Equal(t, int64(0), total)
+ })
+}
diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go
index 779f7756..a8e6f031 100644
--- a/userapi/storage/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -23,7 +23,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
-func NewDatabase(
+func NewUserAPIDatabase(
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
bcryptCost int,