aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/sqlite3
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-11-11 16:41:37 +0000
committerGitHub <noreply@github.com>2022-11-11 16:41:37 +0000
commit529df30b5649e67a2f98114e6640d259cba53566 (patch)
treebcb994ce79916f14c9a11cd11f32063411332585 /userapi/storage/sqlite3
parente177e0ae73d7cc34ffb9869681a6bf177f805205 (diff)
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
Diffstat (limited to 'userapi/storage/sqlite3')
-rw-r--r--userapi/storage/sqlite3/account_data_table.go33
-rw-r--r--userapi/storage/sqlite3/accounts_table.go53
-rw-r--r--userapi/storage/sqlite3/deltas/20200929203058_is_active.go1
-rw-r--r--userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go1
-rw-r--r--userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go1
-rw-r--r--userapi/storage/sqlite3/deltas/2022110411000000_server_names.go108
-rw-r--r--userapi/storage/sqlite3/deltas/2022110411000001_server_names.go28
-rw-r--r--userapi/storage/sqlite3/devices_table.go92
-rw-r--r--userapi/storage/sqlite3/notifications_table.go49
-rw-r--r--userapi/storage/sqlite3/openid_table.go15
-rw-r--r--userapi/storage/sqlite3/profile_table.go52
-rw-r--r--userapi/storage/sqlite3/pusher_table.go32
-rw-r--r--userapi/storage/sqlite3/storage.go28
-rw-r--r--userapi/storage/sqlite3/threepid_table.go26
14 files changed, 365 insertions, 154 deletions
diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go
index af12decb..2fbdc573 100644
--- a/userapi/storage/sqlite3/account_data_table.go
+++ b/userapi/storage/sqlite3/account_data_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@@ -28,27 +29,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
- content TEXT NOT NULL,
-
- PRIMARY KEY(localpart, room_id, type)
+ content TEXT NOT NULL
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
`
const insertAccountDataSQL = `
- INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
- ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
+ INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
+ ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
`
const selectAccountDataSQL = "" +
- "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
+ "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" +
- "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
+ "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct {
db *sql.DB
@@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
}
func (s *accountDataStatements) InsertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string, content json.RawMessage,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
+ _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return err
}
func (s *accountDataStatements) SelectAccountData(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
- rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return nil, nil, err
}
@@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
}
func (s *accountDataStatements) SelectAccountDataByType(
- ctx context.Context, localpart, roomID, dataType string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index 671c1aa0..f4ebe215 100644
--- a/userapi/storage/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
@@ -34,7 +34,8 @@ const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
@@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
`
const insertAccountSQL = "" +
- "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
+ "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
- "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
+ "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
- "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
+ "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
- "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
+ "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
- "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
+ "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
- "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
+ "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct {
db *sql.DB
@@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
- ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
+ ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
+ hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if accountType != api.AccountTypeAppService {
- _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
+ _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
- _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
+ _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
return &api.Account{
Localpart: localpart,
- UserID: userutil.MakeUserID(localpart, s.serverName),
- ServerName: s.serverName,
+ UserID: userutil.MakeUserID(localpart, serverName),
+ ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) UpdatePassword(
- ctx context.Context, localpart, passwordHash string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ passwordHash string,
) (err error) {
- _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
- _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
+ _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
- err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
+ err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
- err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
+ err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String
}
- acc.UserID = userutil.MakeUserID(localpart, s.serverName)
- acc.ServerName = s.serverName
-
+ acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
return &acc, nil
}
func (s *accountsStatements) SelectNewNumericLocalpart(
- ctx context.Context, txn *sql.Tx,
+ ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
- err = stmt.QueryRowContext(ctx).Scan(&id)
+ err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows {
return 1, nil
}
diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
index 9158cb36..2de85005 100644
--- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
+++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
@@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
+ server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
index a9224db6..636ce4ef 100644
--- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
@@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
+ server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
index 230bc143..471e496c 100644
--- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
+++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
@@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
+ server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
new file mode 100644
index 00000000..c11ea684
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
@@ -0,0 +1,108 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+var serverNamesTables = []string{
+ "userapi_accounts",
+ "userapi_account_datas",
+ "userapi_devices",
+ "userapi_notifications",
+ "userapi_openid_tokens",
+ "userapi_profiles",
+ "userapi_pushers",
+ "userapi_threepids",
+}
+
+// These tables have a PRIMARY KEY constraint which we need to drop so
+// that we can recreate a new unique index that contains the server name.
+var serverNamesDropPK = []string{
+ "userapi_accounts",
+ "userapi_account_datas",
+ "userapi_profiles",
+}
+
+// These indices are out of date so let's drop them. They will get recreated
+// automatically.
+var serverNamesDropIndex = []string{
+ "userapi_pusher_localpart_idx",
+ "userapi_pusher_app_id_pushkey_localpart_idx",
+}
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
+ pq.QuoteIdentifier(table),
+ )
+ var c int
+ if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
+ continue
+ }
+ q = fmt.Sprintf(
+ "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
+ pq.QuoteIdentifier(table),
+ )
+ if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
+ logrus.Infof("Table %s already has column, skipping", table)
+ continue
+ }
+ if c == 0 {
+ q = fmt.Sprintf(
+ "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
+ pq.QuoteIdentifier(table),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("add server name to %q error: %w", table, err)
+ }
+ }
+ }
+ for _, table := range serverNamesDropPK {
+ q := fmt.Sprintf(
+ "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
+ pq.QuoteIdentifier(table),
+ )
+ var c int
+ var sql string
+ if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
+ continue
+ }
+ q = fmt.Sprintf(`
+ %s; -- create temporary table
+ INSERT INTO %s SELECT * FROM %s; -- copy data
+ DROP TABLE %s; -- drop original table
+ ALTER TABLE %s RENAME TO %s; -- rename new table
+ `,
+ strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
+ table+"_tmp", table, // copy data
+ table, // drop original table
+ table+"_tmp", table, // rename new table
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop PK from %q error: %w", table, err)
+ }
+ }
+ for _, index := range serverNamesDropIndex {
+ q := fmt.Sprintf(
+ "DROP INDEX IF EXISTS %s;",
+ pq.QuoteIdentifier(index),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop index %q error: %w", index, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go
new file mode 100644
index 00000000..04a47fa7
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go
@@ -0,0 +1,28 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "UPDATE %s SET server_name = %s WHERE server_name = '';",
+ pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("write server names to %q error: %w", table, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index e53a0806..c5db34bd 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
+ server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
- UNIQUE (localpart, device_id)
+ UNIQUE (localpart, server_name, device_id)
);
`
const insertDeviceSQL = "" +
- "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
- " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
+ "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
- "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
+ "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
- "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
+ "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
- "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
+ "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
- "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
+ "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
+ "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
- "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
+ "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
db *sql.DB
@@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
- ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
- displayName *string, ipAddr, userAgent string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
@@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
return nil, err
}
sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
ID: id,
- UserID: userutil.MakeUserID(localpart, s.serverName),
+ UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
}
func (s *devicesStatements) DeleteDevice(
- ctx context.Context, txn *sql.Tx, id, localpart string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
+ _, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
func (s *devicesStatements) DeleteDevices(
- ctx context.Context, txn *sql.Tx, localpart string, devices []string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ devices []string,
) error {
- orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
+ orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
+ params := make([]interface{}, len(devices)+2)
params[0] = localpart
+ params[1] = serverName
for i, v := range devices {
- params[i+1] = v
+ params[i+2] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) DeleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
+ _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
- ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
- err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
+ err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
- ctx context.Context, localpart, deviceID string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt
var lastseenTS sql.NullInt64
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
+ err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
}
func (s *devicesStatements) SelectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
@@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() {
- if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
-func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
+func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
- _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}
diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go
index a35ec7be..ef39d027 100644
--- a/userapi/storage/sqlite3/notifications_table.go
+++ b/userapi/storage/sqlite3/notifications_table.go
@@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
@@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE
);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
`
const insertNotificationSQL = "" +
- "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+ "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" +
- "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" +
- "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" +
- "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
- "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
- ") AND NOT read ORDER BY localpart, id LIMIT $4"
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
+ "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" +
- "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
- "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
- "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+ "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" +
@@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
}
// Insert inserts a notification into the database.
-func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
@@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil {
return err
}
- _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
-func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
}
// UpdateRead updates the "read" value for an event.
-func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil
}
-func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
- rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
@@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err()
}
-func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return
}
-func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return
}
diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go
index 875f1a9a..f0642974 100644
--- a/userapi/storage/sqlite3/openid_table.go
+++ b/userapi/storage/sqlite3/openid_table.go
@@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
const insertOpenIDTokenSQL = "" +
- "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
- "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
+ "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB
@@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
- token, localpart string,
+ token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
- _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
+ _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return
}
@@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
+ var localpart string
+ var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
- &openIDTokenAttrs.UserID,
+ &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
+ openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index b6130a1e..867026d7 100644
--- a/userapi/storage/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
@@ -23,36 +23,40 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
`
const insertProfileSQL = "" +
- "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
+ "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
+ "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
@@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
}
func (s *profilesStatements) InsertProfile(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
+ _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return err
}
func (s *profilesStatements) SelectProfileByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
- err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
- &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
+ err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
+ &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
@@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
}
func (s *profilesStatements) SetAvatarURL(
- ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ avatarURL string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
- Localpart: localpart,
- AvatarURL: avatarURL,
+ Localpart: localpart,
+ ServerName: string(serverName),
+ AvatarURL: avatarURL,
}
- old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
- err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
+ err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
- ctx context.Context, txn *sql.Tx, localpart string, displayName string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ displayName string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
+ ServerName: string(serverName),
DisplayName: displayName,
}
- old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
- err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
+ err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
return profile, true, err
}
@@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
- if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
+ if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
if profile.Localpart != s.serverNoticesLocalpart {
diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go
index 4de0a9f0..c9d451dc 100644
--- a/userapi/storage/sqlite3/pusher_table.go
+++ b/userapi/storage/sqlite3/pusher_table.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
@@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
-CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
-CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
- "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
- "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
- "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+ "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
+ "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" +
- "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
- "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
- pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+ pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
- logrus.Debugf("Created pusher %d", session_id)
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
return err
}
func (s *pushersStatements) SelectPushers(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
- rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
- ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+ ctx context.Context, txn *sql.Tx, appid, pushkey,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index dd33dc0c..85a1f706 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -15,6 +15,8 @@
package sqlite3
import (
+ "context"
+ "database/sql"
"fmt"
"time"
@@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNames(ctx, txn, serverName)
+ },
+ })
if err = m.Up(base.Context()); err != nil {
return nil, err
}
- accountDataTable, err := NewSQLiteAccountDataTable(db)
- if err != nil {
- return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
- }
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
}
+ accountDataTable, err := NewSQLiteAccountDataTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
+ }
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
@@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
}
+
+ m = sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names populate",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNamesPopulate(ctx, txn, serverName)
+ },
+ })
+ if err = m.Up(base.Context()); err != nil {
+ return nil, err
+ }
+
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go
index 73af139d..2db7d588 100644
--- a/userapi/storage/sqlite3/threepid_table.go
+++ b/userapi/storage/sqlite3/threepid_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
-CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
+CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
`
const selectLocalpartForThreePIDSQL = "" +
- "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
+ "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
- "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
+ "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
- "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
-) (localpart string, err error) {
+) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
- err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
+ err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
- return "", nil
+ return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
- rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
- ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
+ ctx context.Context, txn *sql.Tx, threepid, medium,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
- _, err = stmt.ExecContext(ctx, threepid, medium, localpart)
+ _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err
}