diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-11-11 16:41:37 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-11 16:41:37 +0000 |
commit | 529df30b5649e67a2f98114e6640d259cba53566 (patch) | |
tree | bcb994ce79916f14c9a11cd11f32063411332585 /userapi/storage/sqlite3 | |
parent | e177e0ae73d7cc34ffb9869681a6bf177f805205 (diff) |
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
Diffstat (limited to 'userapi/storage/sqlite3')
-rw-r--r-- | userapi/storage/sqlite3/account_data_table.go | 33 | ||||
-rw-r--r-- | userapi/storage/sqlite3/accounts_table.go | 53 | ||||
-rw-r--r-- | userapi/storage/sqlite3/deltas/20200929203058_is_active.go | 1 | ||||
-rw-r--r-- | userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go | 1 | ||||
-rw-r--r-- | userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go | 1 | ||||
-rw-r--r-- | userapi/storage/sqlite3/deltas/2022110411000000_server_names.go | 108 | ||||
-rw-r--r-- | userapi/storage/sqlite3/deltas/2022110411000001_server_names.go | 28 | ||||
-rw-r--r-- | userapi/storage/sqlite3/devices_table.go | 92 | ||||
-rw-r--r-- | userapi/storage/sqlite3/notifications_table.go | 49 | ||||
-rw-r--r-- | userapi/storage/sqlite3/openid_table.go | 15 | ||||
-rw-r--r-- | userapi/storage/sqlite3/profile_table.go | 52 | ||||
-rw-r--r-- | userapi/storage/sqlite3/pusher_table.go | 32 | ||||
-rw-r--r-- | userapi/storage/sqlite3/storage.go | 28 | ||||
-rw-r--r-- | userapi/storage/sqlite3/threepid_table.go | 26 |
14 files changed, 365 insertions, 154 deletions
diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index af12decb..2fbdc573 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -28,27 +29,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { db *sql.DB @@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { - _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) return err } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 671c1aa0..f4ebe215 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -34,7 +34,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1" type accountsStatements struct { db *sql.DB @@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) if err == sql.ErrNoRows { return 1, nil } diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9158cb36..2de85005 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index a9224db6..636ce4ef 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 230bc143..471e496c 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go new file mode 100644 index 00000000..c11ea684 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -0,0 +1,108 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +var serverNamesDropPK = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + continue + } + q = fmt.Sprintf( + "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'", + pq.QuoteIdentifier(table), + ) + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 { + logrus.Infof("Table %s already has column, skipping", table) + continue + } + if c == 0 { + q = fmt.Sprintf( + "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + } + for _, table := range serverNamesDropPK { + q := fmt.Sprintf( + "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + var sql string + if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 { + continue + } + q = fmt.Sprintf(` + %s; -- create temporary table + INSERT INTO %s SELECT * FROM %s; -- copy data + DROP TABLE %s; -- drop original table + ALTER TABLE %s RENAME TO %s; -- rename new table + `, + strings.Replace(sql, table, table+"_tmp", 1), // create temporary table + table+"_tmp", table, // copy data + table, // drop original table + table+"_tmp", table, // rename new table + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop PK from %q error: %w", table, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go new file mode 100644 index 00000000..04a47fa7 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index e53a0806..c5db34bd 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, ip TEXT, user_agent TEXT, - UNIQUE (localpart, device_id) + UNIQUE (localpart, server_name, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { db *sql.DB @@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) + params := make([]interface{}, len(devices)+2) params[0] = localpart + params[1] = serverName for i, v := range devices { - params[i+1] = v + params[i+2] = v } _, err = stmt.ExecContext(ctx, params...) return err } func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt var lastseenTS sql.NullInt64 - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID( } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } @@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index a35ec7be..ef39d027 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 875f1a9a..f0642974 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB @@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index b6130a1e..867026d7 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,36 +23,40 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName) return profile, true, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL) return profile, true, err } @@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index 4de0a9f0..c9d451dc 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index dd33dc0c..85a1f706 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,6 +15,8 @@ package sqlite3 import ( + "context" + "database/sql" "fmt" "time" @@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewSQLiteAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) - } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) @@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 73af139d..2db7d588 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return err } |