aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-11-11 16:41:37 +0000
committerGitHub <noreply@github.com>2022-11-11 16:41:37 +0000
commit529df30b5649e67a2f98114e6640d259cba53566 (patch)
treebcb994ce79916f14c9a11cd11f32063411332585 /userapi
parente177e0ae73d7cc34ffb9869681a6bf177f805205 (diff)
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go55
-rw-r--r--userapi/api/api_trace.go4
-rw-r--r--userapi/consumers/clientapi.go4
-rw-r--r--userapi/consumers/roomserver.go58
-rw-r--r--userapi/internal/api.go101
-rw-r--r--userapi/internal/api_logintoken.go2
-rw-r--r--userapi/inthttp/client.go3
-rw-r--r--userapi/inthttp/server.go15
-rw-r--r--userapi/producers/syncapi.go4
-rw-r--r--userapi/storage/interface.go66
-rw-r--r--userapi/storage/postgres/account_data_table.go33
-rw-r--r--userapi/storage/postgres/accounts_table.go57
-rw-r--r--userapi/storage/postgres/deltas/2022110411000000_server_names.go81
-rw-r--r--userapi/storage/postgres/deltas/2022110411000001_server_names.go28
-rw-r--r--userapi/storage/postgres/devices_table.go88
-rw-r--r--userapi/storage/postgres/notifications_table.go49
-rw-r--r--userapi/storage/postgres/openid_table.go15
-rw-r--r--userapi/storage/postgres/profile_table.go48
-rw-r--r--userapi/storage/postgres/pusher_table.go32
-rw-r--r--userapi/storage/postgres/storage.go28
-rw-r--r--userapi/storage/postgres/threepid_table.go26
-rw-r--r--userapi/storage/shared/storage.go204
-rw-r--r--userapi/storage/sqlite3/account_data_table.go33
-rw-r--r--userapi/storage/sqlite3/accounts_table.go53
-rw-r--r--userapi/storage/sqlite3/deltas/20200929203058_is_active.go1
-rw-r--r--userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go1
-rw-r--r--userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go1
-rw-r--r--userapi/storage/sqlite3/deltas/2022110411000000_server_names.go108
-rw-r--r--userapi/storage/sqlite3/deltas/2022110411000001_server_names.go28
-rw-r--r--userapi/storage/sqlite3/devices_table.go92
-rw-r--r--userapi/storage/sqlite3/notifications_table.go49
-rw-r--r--userapi/storage/sqlite3/openid_table.go15
-rw-r--r--userapi/storage/sqlite3/profile_table.go52
-rw-r--r--userapi/storage/sqlite3/pusher_table.go32
-rw-r--r--userapi/storage/sqlite3/storage.go28
-rw-r--r--userapi/storage/sqlite3/threepid_table.go26
-rw-r--r--userapi/storage/storage_test.go132
-rw-r--r--userapi/storage/tables/interface.go68
-rw-r--r--userapi/storage/tables/stats_table_test.go17
-rw-r--r--userapi/userapi_test.go10
-rw-r--r--userapi/util/devices.go8
-rw-r--r--userapi/util/notify.go7
42 files changed, 1108 insertions, 654 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 8d7f783d..d3f5aefc 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -78,7 +78,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
- QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
+ QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
@@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
- Localpart string // Required: The localpart for this account.
- Password string // Required: The new password to set.
- LogoutDevices bool // Optional: Whether to log out all user devices.
+ Localpart string // Required: The localpart for this account.
+ ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
+ Password string // Required: The new password to set.
+ LogoutDevices bool // Optional: Whether to log out all user devices.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
@@ -518,7 +519,8 @@ const (
)
type QueryPushersRequest struct {
- Localpart string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
}
type QueryPushersResponse struct {
@@ -526,14 +528,16 @@ type QueryPushersResponse struct {
}
type PerformPusherSetRequest struct {
- Pusher // Anonymous field because that's how clientapi unmarshals it.
- Localpart string
- Append bool `json:"append"`
+ Pusher // Anonymous field because that's how clientapi unmarshals it.
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+ Append bool `json:"append"`
}
type PerformPusherDeletionRequest struct {
- Localpart string
- SessionID int64
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+ SessionID int64
}
// Pusher represents a push notification subscriber
@@ -571,10 +575,11 @@ type QueryPushRulesResponse struct {
}
type QueryNotificationsRequest struct {
- Localpart string `json:"localpart"` // Required.
- From string `json:"from,omitempty"`
- Limit int `json:"limit,omitempty"`
- Only string `json:"only,omitempty"`
+ Localpart string `json:"localpart"` // Required.
+ ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
+ From string `json:"from,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ Only string `json:"only,omitempty"`
}
type QueryNotificationsResponse struct {
@@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
Changed bool `json:"changed"`
}
+type QueryNumericLocalpartRequest struct {
+ ServerName gomatrixserverlib.ServerName
+}
+
type QueryNumericLocalpartResponse struct {
ID int64
}
type QueryAccountAvailabilityRequest struct {
- Localpart string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
}
type QueryAccountAvailabilityResponse struct {
@@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
}
type QueryAccountByPasswordRequest struct {
- Localpart, PlaintextPassword string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+ PlaintextPassword string
}
type QueryAccountByPasswordResponse struct {
@@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
}
type QueryLocalpartForThreePIDResponse struct {
- Localpart string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartRequest struct {
- Localpart string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartResponse struct {
@@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct {
- ThreePID, Localpart, Medium string
+ ThreePID string
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+ Medium string
}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index 90834f7e..ce661770 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
return err
}
-func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
- err := t.Impl.QueryNumericLocalpart(ctx, res)
+func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
+ err := t.Impl.QueryNumericLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err
}
diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go
index 79f1bf06..42ae72e7 100644
--- a/userapi/consumers/clientapi.go
+++ b/userapi/consumers/clientapi.go
@@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return false
}
- updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
+ updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
@@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
if !updated {
return true
}
- if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
+ if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index b6b30a09..5d8924dd 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers {
// Copy any existing push rules from old -> new room
- if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
+ if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
// preserve m.direct room state
- if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
+ if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
return err
}
// copy existing m.tag entries, if any
- if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
+ if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
}
return nil
}
-func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
- pushRules, err := s.db.QueryPushRules(ctx, localpart)
+func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
+ pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err)
}
@@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
if err != nil {
return err
}
- if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
+ if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err)
}
}
@@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
}
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
-func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
+func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
// this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 {
return nil
}
// Get direct message state
- directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
+ directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err)
}
@@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
if err != nil {
return true
}
- if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
+ if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
return true
}
}
@@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
return nil
}
-func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
- tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
+func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
+ tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if tag == nil {
return nil
}
- return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
+ return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
}
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
@@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil {
- return err
+ return fmt.Errorf("s.evaluatePushRules: %w", err)
}
a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil {
- return err
+ return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
}
// TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
@@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
return nil
}
- devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
+ devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil {
- return err
+ return fmt.Errorf("s.localPushDevices: %w", err)
}
n := &api.Notification{
@@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()),
}
- if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
- return err
+ if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
+ return fmt.Errorf("s.db.InsertNotification: %w", err)
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
- return err
+ return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
}
// We do this after InsertNotification. Thus, this should always return >=1.
- userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
+ userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
if err != nil {
- return err
+ return fmt.Errorf("s.db.GetNotificationCount: %w", err)
}
log.WithFields(log.Fields{
@@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
}
if len(rejected) > 0 {
- s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
+ s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
}
}()
@@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
}
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
- data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
+ data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
if err != nil {
return nil, err
}
@@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
return nil, fmt.Errorf("user %s is ignored", sender)
}
}
- ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
+ ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
if err != nil {
return nil, err
}
@@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format].
-func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
- pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
+func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
+ pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil {
- return nil, "", err
+ return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
}
var profileTag string
@@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
-func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
+func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,
@@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices {
- if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
+ if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher")
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 9ca76965..3f256457 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
+ if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
return fmt.Errorf("failed to save account data: %w", err)
}
@@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
- deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
+ deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
@@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
- if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
+ if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
@@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
- // XXXX: Use the server name here
- acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ return fmt.Errorf("server name %s is not local", serverName)
+ }
+ acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
- return err
+ if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
+ return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
@@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
- if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
+ return fmt.Errorf("server name %s is not local", req.ServerName)
+ }
+ if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err
}
if req.LogoutDevices {
- if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
+ if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
return err
}
}
@@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
- _ = serverName
- // XXXX: Use the server name here
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ return fmt.Errorf("server name %s is not local", serverName)
+ }
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
- dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
+ dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
+ devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
- err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
+ err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
}
if err != nil {
return err
@@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
- if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("server name %s is not local", domain)
+ }
+ if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
}
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
- dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("server name %s is not local", domain)
+ }
+ dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
- err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
}
- prof, err := a.DB.GetProfileByLocalpart(ctx, local)
+ prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
}
- devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
+ devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
if err != nil {
return err
}
@@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
- data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
if err != nil {
return err
}
@@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
- global, rooms, err := a.DB.GetAccountData(ctx, local)
+ global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
if err != nil {
return err
}
@@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
- acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
+ acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil {
return err
}
@@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService,
}
- localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
+ localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil {
return nil, err
}
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
- account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
+ account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err
}
- err := a.DB.DeactivateAccount(ctx, req.Localpart)
+ err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil
return err
}
@@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
if req.Only == "highlight" {
filter = tables.HighlightNotifications
}
- notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
+ notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
if err != nil {
return err
}
@@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
}
}
if req.Pusher.Kind == "" {
- return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
+ return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
- return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
+ return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
- pushers, err := a.DB.GetPushers(ctx, req.Localpart)
+ pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
- err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
+ err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil {
return err
}
@@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
- res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
+ res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err
}
@@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
}
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
}
- pushRules, err := a.DB.QueryPushRules(ctx, localpart)
+ pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil {
return fmt.Errorf("failed to query push rules: %w", err)
}
@@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
- profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+ profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
-func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
- id, err := a.DB.GetNewNumericLocalpart(ctx)
+func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
+ id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil {
return err
}
@@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error
- res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
+ res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err
}
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
- acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
+ acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err {
case sql.ErrNoRows: // user does not exist
return nil
@@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
- profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+ profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
- localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
+ localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil {
return err
}
res.Localpart = localpart
+ res.ServerName = domain
return nil
}
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
- r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
+ r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
@@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
- return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
+ return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
const pushRulesAccountDataType = "m.push_rules"
diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go
index 87f25e5e..3b211db5 100644
--- a/userapi/internal/api_logintoken.go
+++ b/userapi/internal/api_logintoken.go
@@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
}
- if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
+ if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index aa5d46d9..87ae058c 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context,
+ request *api.QueryNumericLocalpartRequest,
response *api.QueryNumericLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
- h.httpClient, ctx, &struct{}{}, response,
+ h.httpClient, ctx, request, response,
)
}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index 99148b76..661fecfa 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -15,12 +15,9 @@
package inthttp
import (
- "net/http"
-
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/util"
)
// nolint: gocyclo
@@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
)
- // TODO: Look at the shape of this
- internalAPIMux.Handle(QueryNumericLocalpartPath,
- httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
- response := api.QueryNumericLocalpartResponse{}
- if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
- return util.ErrorResponse(err)
- }
- return util.JSONResponse{Code: http.StatusOK, JSON: &response}
- }),
+ internalAPIMux.Handle(
+ QueryNumericLocalpartPath,
+ httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
)
internalAPIMux.Handle(
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
index f556ea35..51eaa985 100644
--- a/userapi/producers/syncapi.go
+++ b/userapi/producers/syncapi.go
@@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
// GetAndSendNotificationData reads the database and sends data about unread
// notifications to the Sync API server.
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', userID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
- ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
+ ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
if err != nil {
return err
}
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index 28ef2655..c22b7658 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -29,40 +29,40 @@ import (
)
type Profile interface {
- GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
+ GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
- SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
+ SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
- CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
- GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
- GetNewNumericLocalpart(ctx context.Context) (int64, error)
- CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
- GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
- DeactivateAccount(ctx context.Context, localpart string) (err error)
- SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
+ CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
+ GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
+ GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
+ CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
+ GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
+ DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
+ SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
}
type AccountData interface {
- SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
- GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
+ SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
+ GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
- GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
- QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
+ GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
+ QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
}
type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
- GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
- GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
+ GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
+ GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
@@ -70,12 +70,12 @@ type Device interface {
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
- CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
- UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
- UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
- RemoveDevices(ctx context.Context, localpart string, devices []string) error
+ CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
+ UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
+ UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
+ RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
- RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
+ RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
}
type KeyBackup interface {
@@ -107,26 +107,26 @@ type OpenID interface {
}
type Pusher interface {
- UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
- GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
- RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
+ UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
+ GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
+ RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
RemovePushers(ctx context.Context, appid, pushkey string) error
}
type ThreePID interface {
- SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
+ SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
- GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
- GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
+ GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
+ GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
}
type Notification interface {
- InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
- DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
- SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
- GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
- GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
- GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
+ InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
+ DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
+ SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
+ GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
+ GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
+ GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error
}
diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 0b6a3af6..2a4777d7 100644
--- a/userapi/storage/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
@@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@@ -29,27 +30,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
- content TEXT NOT NULL,
-
- PRIMARY KEY(localpart, room_id, type)
+ content TEXT NOT NULL
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
`
const insertAccountDataSQL = `
- INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
- ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
+ INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
+ ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
- "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
+ "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" +
- "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
+ "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt
@@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
}
func (s *accountDataStatements) InsertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
- _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
+ _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return
}
func (s *accountDataStatements) SelectAccountData(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
- rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return nil, nil, err
}
@@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
}
func (s *accountDataStatements) SelectAccountDataByType(
- ctx context.Context, localpart, roomID, dataType string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index 7c309eb4..31a99652 100644
--- a/userapi/storage/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
@@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
+ "fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
@@ -34,7 +35,8 @@ const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
@@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
`
const insertAccountSQL = "" +
- "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
+ "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
- "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
+ "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
- "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
+ "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
- "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
+ "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
- "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
+ "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
- "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
+ "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
@@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
- ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if accountType != api.AccountTypeAppService {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
+ _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
+ _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("insertAccountStmt: %w", err)
}
return &api.Account{
Localpart: localpart,
- UserID: userutil.MakeUserID(localpart, s.serverName),
- ServerName: s.serverName,
+ UserID: userutil.MakeUserID(localpart, serverName),
+ ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) UpdatePassword(
- ctx context.Context, localpart, passwordHash string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ passwordHash string,
) (err error) {
- _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
- _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
+ _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
- err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
+ err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
- err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
+ err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String
}
- acc.UserID = userutil.MakeUserID(localpart, s.serverName)
- acc.ServerName = s.serverName
-
+ acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
return &acc, nil
}
func (s *accountsStatements) SelectNewNumericLocalpart(
- ctx context.Context, txn *sql.Tx,
+ ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
- err = stmt.QueryRowContext(ctx).Scan(&id)
+ err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
return id + 1, err
}
diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go
new file mode 100644
index 00000000..279e1e5f
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go
@@ -0,0 +1,81 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var serverNamesTables = []string{
+ "userapi_accounts",
+ "userapi_account_datas",
+ "userapi_devices",
+ "userapi_notifications",
+ "userapi_openid_tokens",
+ "userapi_profiles",
+ "userapi_pushers",
+ "userapi_threepids",
+}
+
+// These tables have a PRIMARY KEY constraint which we need to drop so
+// that we can recreate a new unique index that contains the server name.
+// If the new key doesn't exist (i.e. the database was created before the
+// table rename migration) we'll try to drop the old one instead.
+var serverNamesDropPK = map[string]string{
+ "userapi_accounts": "account_accounts",
+ "userapi_account_datas": "account_data",
+ "userapi_profiles": "account_profiles",
+}
+
+// These indices are out of date so let's drop them. They will get recreated
+// automatically.
+var serverNamesDropIndex = []string{
+ "userapi_pusher_localpart_idx",
+ "userapi_pusher_app_id_pushkey_localpart_idx",
+}
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
+ pq.QuoteIdentifier(table),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("add server name to %q error: %w", table, err)
+ }
+ }
+ for newTable, oldTable := range serverNamesDropPK {
+ q := fmt.Sprintf(
+ "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
+ pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
+ }
+ q = fmt.Sprintf(
+ "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
+ pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
+ }
+ }
+ for _, index := range serverNamesDropIndex {
+ q := fmt.Sprintf(
+ "DROP INDEX IF EXISTS %s;",
+ pq.QuoteIdentifier(index),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop index %q error: %w", index, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go
new file mode 100644
index 00000000..04a47fa7
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go
@@ -0,0 +1,28 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "UPDATE %s SET server_name = %s WHERE server_name = '';",
+ pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("write server names to %q error: %w", table, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 8b7fbd6c..2dd21618 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
+ "fmt"
"time"
"github.com/lib/pq"
@@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier.
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable
@@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
);
-- Device IDs must be unique for a given user.
-CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
`
const insertDeviceSQL = "" +
- "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
+ "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
- "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
+ "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
- "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
+ "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
- "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
+ "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
- "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
+ "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
+ "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
- "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
+ "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
@@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
- ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
- displayName *string, ipAddr, userAgent string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
- if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
- return nil, err
+ if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
+ return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
return &api.Device{
ID: id,
- UserID: userutil.MakeUserID(localpart, s.serverName),
+ UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) DeleteDevice(
- ctx context.Context, txn *sql.Tx, id, localpart string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
+ _, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) DeleteDevices(
- ctx context.Context, txn *sql.Tx, localpart string, devices []string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
- _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
+ _, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) DeleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
+ _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
- ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
- err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
+ err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
- ctx context.Context, localpart, deviceID string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
+ err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
var lastseents sql.NullInt64
var displayName sql.NullString
for rows.Next() {
- if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) SelectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
-func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
+func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
- _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}
diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go
index 24a30b2f..dc64b1e7 100644
--- a/userapi/storage/postgres/notifications_table.go
+++ b/userapi/storage/postgres/notifications_table.go
@@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id BIGSERIAL PRIMARY KEY,
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
@@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE
);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
`
const insertNotificationSQL = "" +
- "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+ "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" +
- "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" +
- "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" +
- "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
- "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
- ") AND NOT read ORDER BY localpart, id LIMIT $4"
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
+ "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" +
- "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
- "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
- "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+ "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" +
@@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
}
// Insert inserts a notification into the database.
-func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
@@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil {
return err
}
- _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
-func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
}
// UpdateRead updates the "read" value for an event.
-func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil
}
-func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
- rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
@@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err()
}
-func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return
}
-func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return
}
diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go
index 06ae30d0..68d87f00 100644
--- a/userapi/storage/postgres/openid_table.go
+++ b/userapi/storage/postgres/openid_table.go
@@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
+ "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
const insertOpenIDTokenSQL = "" +
- "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
- "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
+ "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt
@@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
- token, localpart string,
+ token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
- _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
+ _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return
}
@@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
+ var localpart string
+ var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
- &openIDTokenAttrs.UserID,
+ &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
+ openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index 2753b23d..df4e0db6 100644
--- a/userapi/storage/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
@@ -23,42 +23,46 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
`
const insertProfileSQL = "" +
- "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" +
"UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
- " WHERE new.localpart = $2" +
+ " WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
"UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
- " WHERE new.localpart = $2" +
+ " WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
serverNoticesLocalpart string
@@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
}
func (s *profilesStatements) InsertProfile(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
- _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
+ _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return
}
func (s *profilesStatements) SelectProfileByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
- err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
- &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
+ err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
+ &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
@@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
}
func (s *profilesStatements) SetAvatarURL(
- ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ avatarURL string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
- Localpart: localpart,
- AvatarURL: avatarURL,
+ Localpart: localpart,
+ ServerName: string(serverName),
+ AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
- err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
+ err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
- ctx context.Context, txn *sql.Tx, localpart string, displayName string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ displayName string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
+ ServerName: string(serverName),
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
- err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
+ err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
}
@@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
- if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
+ if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
if profile.Localpart != s.serverNoticesLocalpart {
diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go
index 6fb714fb..707b3bd2 100644
--- a/userapi/storage/postgres/pusher_table.go
+++ b/userapi/storage/postgres/pusher_table.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id BIGSERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
@@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
-CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
-CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
- "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
- "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
- "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+ "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
+ "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" +
- "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
- "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
- pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+ pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
- logrus.Debugf("Created pusher %d", session_id)
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
return err
}
func (s *pushersStatements) SelectPushers(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
- rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
+ rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
- ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+ ctx context.Context, txn *sql.Tx, appid, pushkey,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index c059e3e6..92dc4808 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -15,6 +15,8 @@
package postgres
import (
+ "context"
+ "database/sql"
"fmt"
"time"
@@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNames(ctx, txn, serverName)
+ },
+ })
if err = m.Up(base.Context()); err != nil {
return nil, err
}
- accountDataTable, err := NewPostgresAccountDataTable(db)
- if err != nil {
- return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
- }
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
+ accountDataTable, err := NewPostgresAccountDataTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
+ }
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
@@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
}
+
+ m = sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names populate",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNamesPopulate(ctx, txn, serverName)
+ },
+ })
+ if err = m.Up(base.Context()); err != nil {
+ return nil, err
+ }
+
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go
index 11af7616..f41c4312 100644
--- a/userapi/storage/postgres/threepid_table.go
+++ b/userapi/storage/postgres/threepid_table.go
@@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
-CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
+CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
`
const selectLocalpartForThreePIDSQL = "" +
- "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
+ "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
- "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
+ "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
- "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
-) (localpart string, err error) {
+) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
- err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
+ err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
- return "", nil
+ return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
- rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
- ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
+ ctx context.Context, txn *sql.Tx, threepid, medium,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
- _, err = stmt.ExecContext(ctx, threepid, medium, localpart)
+ _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return
}
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index f8b8d02c..f549dcef 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -68,9 +68,10 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
- ctx context.Context, localpart, plaintextPassword string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ plaintextPassword string,
) (*api.Account, error) {
- hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
+ hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
if err != nil {
return nil, err
}
@@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
- return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
+ return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
- return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
+ return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
- ctx context.Context, localpart string, avatarURL string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ avatarURL string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+ profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
return err
})
return
@@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
- ctx context.Context, localpart string, displayName string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ displayName string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+ profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
return err
})
return
@@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
- ctx context.Context, localpart, plaintextPassword string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
- return d.Accounts.UpdatePassword(ctx, localpart, hash)
+ return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
})
}
@@ -132,21 +139,22 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
- ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
- numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
+ numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil {
- return err
+ return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
- acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
+ acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
return err
})
return
@@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
- ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var err error
var account *api.Account
@@ -167,28 +177,28 @@ func (d *Database) createAccount(
return nil, err
}
}
- if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
+ if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
- if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
- return nil, err
+ if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
+ return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
}
- pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("json.Marshal: %w", err)
}
- if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
- return nil, err
+ if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
+ return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
}
return account, nil
}
func (d *Database) QueryPushRules(
ctx context.Context,
- localpart string,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (*pushrules.AccountRuleSets, error) {
- data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
+ data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
if err != nil {
return nil, err
}
@@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
// If we didn't find any default push rules then we should just generate some
// fresh ones.
if len(data) == 0 {
- pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
+ if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
return fmt.Errorf("failed to save default push rules: %w", dbErr)
}
return nil
@@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string, content json.RawMessage,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
+ return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
-func (d *Database) GetAccountData(ctx context.Context, localpart string) (
+func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
- return d.AccountDatas.SelectAccountData(ctx, localpart)
+ return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
}
// GetAccountDataByType returns account data matching a given
@@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
- ctx context.Context, localpart, roomID, dataType string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string,
) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType(
- ctx, localpart, roomID, dataType,
+ ctx, localpart, serverName, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
- ctx context.Context,
+ ctx context.Context, serverName gomatrixserverlib.ServerName,
) (int64, error) {
- return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
+ return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
- ctx context.Context, threepid, localpart, medium string,
+ ctx context.Context, threepid string,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- user, err := d.ThreePIDs.SelectLocalpartForThreePID(
+ user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
@@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse
}
- return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
+ return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
})
}
@@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
-) (localpart string, err error) {
+) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
@@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
- return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
+ return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
-func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
- _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
+func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
+ _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows {
return true, nil
}
@@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
-func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
+func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
// try to get the account with lowercase localpart (majority)
- acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
+ acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
if err == sql.ErrNoRows {
- acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
+ acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
}
return acc, err
}
@@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
-func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
+func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
- return d.Accounts.DeactivateAccount(ctx, localpart)
+ return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
})
}
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
- token, localpart string,
+ token, userID string,
) (int64, error) {
+ localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return 0, nil
+ }
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
- err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
})
return expiresAtMS, err
}
@@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string,
) (*api.Device, error) {
- return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
+ return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Device, error) {
- return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
+ return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string, ipAddr, userAgent string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
- if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
- dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
@@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
- dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
@@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
})
}
@@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
return err
}
return nil
@@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart, exceptDeviceID string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
+ devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
if err != nil {
return err
}
- if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
+ if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
@@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
}
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
-func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
+func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
+ return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
})
}
@@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token)
}
-func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
+func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
+ return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
-func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
+func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
+ affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
return err
})
return
}
-func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
+func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
+ affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
return err
})
return
}
-func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
- return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
+func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
}
-func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
- return d.Notifications.SelectCount(ctx, nil, localpart, filter)
+func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
+ return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
}
-func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
- return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
+func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
+ return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
}
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
@@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
}
func (d *Database) UpsertPusher(
- ctx context.Context, p api.Pusher, localpart string,
+ ctx context.Context, p api.Pusher,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
@@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
p.ProfileTag,
p.Language,
string(data),
- localpart)
+ localpart,
+ serverName)
})
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
- return d.Pushers.SelectPushers(ctx, nil, localpart)
+ return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
}
// RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
- ctx context.Context, appid, pushkey, localpart string,
+ ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
+ err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
if err == sql.ErrNoRows {
return nil
}
diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go
index af12decb..2fbdc573 100644
--- a/userapi/storage/sqlite3/account_data_table.go
+++ b/userapi/storage/sqlite3/account_data_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@@ -28,27 +29,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
- content TEXT NOT NULL,
-
- PRIMARY KEY(localpart, room_id, type)
+ content TEXT NOT NULL
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
`
const insertAccountDataSQL = `
- INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
- ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
+ INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
+ ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
`
const selectAccountDataSQL = "" +
- "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
+ "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" +
- "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
+ "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct {
db *sql.DB
@@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
}
func (s *accountDataStatements) InsertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string, content json.RawMessage,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
+ _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return err
}
func (s *accountDataStatements) SelectAccountData(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
- rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return nil, nil, err
}
@@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
}
func (s *accountDataStatements) SelectAccountDataByType(
- ctx context.Context, localpart, roomID, dataType string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index 671c1aa0..f4ebe215 100644
--- a/userapi/storage/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
@@ -34,7 +34,8 @@ const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
@@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
`
const insertAccountSQL = "" +
- "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
+ "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
- "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
+ "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
- "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
+ "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
- "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
+ "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
- "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
+ "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
- "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
+ "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct {
db *sql.DB
@@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
- ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
+ ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
+ hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if accountType != api.AccountTypeAppService {
- _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
+ _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
- _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
+ _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
return &api.Account{
Localpart: localpart,
- UserID: userutil.MakeUserID(localpart, s.serverName),
- ServerName: s.serverName,
+ UserID: userutil.MakeUserID(localpart, serverName),
+ ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) UpdatePassword(
- ctx context.Context, localpart, passwordHash string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
+ passwordHash string,
) (err error) {
- _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
- _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
+ _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
- err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
+ err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
- err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
+ err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String
}
- acc.UserID = userutil.MakeUserID(localpart, s.serverName)
- acc.ServerName = s.serverName
-
+ acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
return &acc, nil
}
func (s *accountsStatements) SelectNewNumericLocalpart(
- ctx context.Context, txn *sql.Tx,
+ ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
- err = stmt.QueryRowContext(ctx).Scan(&id)
+ err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows {
return 1, nil
}
diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
index 9158cb36..2de85005 100644
--- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
+++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
@@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
+ server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
index a9224db6..636ce4ef 100644
--- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
@@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
+ server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
index 230bc143..471e496c 100644
--- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
+++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
@@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
+ server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,
diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
new file mode 100644
index 00000000..c11ea684
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
@@ -0,0 +1,108 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+var serverNamesTables = []string{
+ "userapi_accounts",
+ "userapi_account_datas",
+ "userapi_devices",
+ "userapi_notifications",
+ "userapi_openid_tokens",
+ "userapi_profiles",
+ "userapi_pushers",
+ "userapi_threepids",
+}
+
+// These tables have a PRIMARY KEY constraint which we need to drop so
+// that we can recreate a new unique index that contains the server name.
+var serverNamesDropPK = []string{
+ "userapi_accounts",
+ "userapi_account_datas",
+ "userapi_profiles",
+}
+
+// These indices are out of date so let's drop them. They will get recreated
+// automatically.
+var serverNamesDropIndex = []string{
+ "userapi_pusher_localpart_idx",
+ "userapi_pusher_app_id_pushkey_localpart_idx",
+}
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
+ pq.QuoteIdentifier(table),
+ )
+ var c int
+ if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
+ continue
+ }
+ q = fmt.Sprintf(
+ "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
+ pq.QuoteIdentifier(table),
+ )
+ if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
+ logrus.Infof("Table %s already has column, skipping", table)
+ continue
+ }
+ if c == 0 {
+ q = fmt.Sprintf(
+ "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
+ pq.QuoteIdentifier(table),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("add server name to %q error: %w", table, err)
+ }
+ }
+ }
+ for _, table := range serverNamesDropPK {
+ q := fmt.Sprintf(
+ "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
+ pq.QuoteIdentifier(table),
+ )
+ var c int
+ var sql string
+ if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
+ continue
+ }
+ q = fmt.Sprintf(`
+ %s; -- create temporary table
+ INSERT INTO %s SELECT * FROM %s; -- copy data
+ DROP TABLE %s; -- drop original table
+ ALTER TABLE %s RENAME TO %s; -- rename new table
+ `,
+ strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
+ table+"_tmp", table, // copy data
+ table, // drop original table
+ table+"_tmp", table, // rename new table
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop PK from %q error: %w", table, err)
+ }
+ }
+ for _, index := range serverNamesDropIndex {
+ q := fmt.Sprintf(
+ "DROP INDEX IF EXISTS %s;",
+ pq.QuoteIdentifier(index),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("drop index %q error: %w", index, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go
new file mode 100644
index 00000000..04a47fa7
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go
@@ -0,0 +1,28 @@
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// I know what you're thinking: you're wondering "why doesn't this use $1
+// and pass variadic parameters to ExecContext?" — the answer is because
+// PostgreSQL doesn't expect the table name to be specified as a substituted
+// argument in that way so it results in a syntax error in the query.
+
+func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
+ for _, table := range serverNamesTables {
+ q := fmt.Sprintf(
+ "UPDATE %s SET server_name = %s WHERE server_name = '';",
+ pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
+ )
+ if _, err := tx.ExecContext(ctx, q); err != nil {
+ return fmt.Errorf("write server names to %q error: %w", table, err)
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index e53a0806..c5db34bd 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
+ server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
- UNIQUE (localpart, device_id)
+ UNIQUE (localpart, server_name, device_id)
);
`
const insertDeviceSQL = "" +
- "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
- " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
+ "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
- "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
+ "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
- "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
+ "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
- "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
+ "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
- "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
+ "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
- "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
+ "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
+ "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
- "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
+ "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
db *sql.DB
@@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
- ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
- displayName *string, ipAddr, userAgent string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
@@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
return nil, err
}
sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
ID: id,
- UserID: userutil.MakeUserID(localpart, s.serverName),
+ UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
}
func (s *devicesStatements) DeleteDevice(
- ctx context.Context, txn *sql.Tx, id, localpart string,
+ ctx context.Context, txn *sql.Tx, id string,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
+ _, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
func (s *devicesStatements) DeleteDevices(
- ctx context.Context, txn *sql.Tx, localpart string, devices []string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ devices []string,
) error {
- orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
+ orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
+ params := make([]interface{}, len(devices)+2)
params[0] = localpart
+ params[1] = serverName
for i, v := range devices {
- params[i+1] = v
+ params[i+2] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) DeleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
+ _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
- ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
- err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
+ err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
- ctx context.Context, localpart, deviceID string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt
var lastseenTS sql.NullInt64
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
+ err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
}
func (s *devicesStatements) SelectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
@@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
+ var serverName gomatrixserverlib.ServerName
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() {
- if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
- dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
-func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
+func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
- _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
+ _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}
diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go
index a35ec7be..ef39d027 100644
--- a/userapi/storage/sqlite3/notifications_table.go
+++ b/userapi/storage/sqlite3/notifications_table.go
@@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
@@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE
);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
-CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
`
const insertNotificationSQL = "" +
- "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+ "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" +
- "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" +
- "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" +
- "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
- "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
- ") AND NOT read ORDER BY localpart, id LIMIT $4"
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
+ "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" +
- "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
- "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
- "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+ "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" +
@@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
}
// Insert inserts a notification into the database.
-func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
@@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil {
return err
}
- _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
-func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
}
// UpdateRead updates the "read" value for an event.
-func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
- res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil
}
-func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
- rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
@@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err()
}
-func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return
}
-func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
- err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
+ err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return
}
diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go
index 875f1a9a..f0642974 100644
--- a/userapi/storage/sqlite3/openid_table.go
+++ b/userapi/storage/sqlite3/openid_table.go
@@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
const insertOpenIDTokenSQL = "" +
- "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
- "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
+ "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB
@@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
- token, localpart string,
+ token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
- _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
+ _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return
}
@@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
+ var localpart string
+ var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
- &openIDTokenAttrs.UserID,
+ &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
+ openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index b6130a1e..867026d7 100644
--- a/userapi/storage/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
@@ -23,36 +23,40 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
- localpart TEXT NOT NULL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
+
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
`
const insertProfileSQL = "" +
- "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
+ "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
+ "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
- "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
+ "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
@@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
}
func (s *profilesStatements) InsertProfile(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
+ _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return err
}
func (s *profilesStatements) SelectProfileByLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
- err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
- &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
+ err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
+ &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
@@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
}
func (s *profilesStatements) SetAvatarURL(
- ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ avatarURL string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
- Localpart: localpart,
- AvatarURL: avatarURL,
+ Localpart: localpart,
+ ServerName: string(serverName),
+ AvatarURL: avatarURL,
}
- old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
- err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
+ err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
- ctx context.Context, txn *sql.Tx, localpart string, displayName string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ displayName string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
+ ServerName: string(serverName),
DisplayName: displayName,
}
- old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
- err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
+ err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
return profile, true, err
}
@@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
- if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
+ if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
if profile.Localpart != s.serverNoticesLocalpart {
diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go
index 4de0a9f0..c9d451dc 100644
--- a/userapi/storage/sqlite3/pusher_table.go
+++ b/userapi/storage/sqlite3/pusher_table.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
@@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
-CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
-CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
- "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
- "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
- "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+ "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
+ "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" +
- "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
- "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
- pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+ pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
- logrus.Debugf("Created pusher %d", session_id)
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
return err
}
func (s *pushersStatements) SelectPushers(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx,
+ localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
- rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
- ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+ ctx context.Context, txn *sql.Tx, appid, pushkey,
+ localpart string, serverName gomatrixserverlib.ServerName,
) error {
- _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index dd33dc0c..85a1f706 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -15,6 +15,8 @@
package sqlite3
import (
+ "context"
+ "database/sql"
"fmt"
"time"
@@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNames(ctx, txn, serverName)
+ },
+ })
if err = m.Up(base.Context()); err != nil {
return nil, err
}
- accountDataTable, err := NewSQLiteAccountDataTable(db)
- if err != nil {
- return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
- }
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
}
+ accountDataTable, err := NewSQLiteAccountDataTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
+ }
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
@@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
}
+
+ m = sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "userapi: server names populate",
+ Up: func(ctx context.Context, txn *sql.Tx) error {
+ return deltas.UpServerNamesPopulate(ctx, txn, serverName)
+ },
+ })
+ if err = m.Up(base.Context()); err != nil {
+ return nil, err
+ }
+
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go
index 73af139d..2db7d588 100644
--- a/userapi/storage/sqlite3/threepid_table.go
+++ b/userapi/storage/sqlite3/threepid_table.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
+ server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
-CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
+CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
`
const selectLocalpartForThreePIDSQL = "" +
- "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
+ "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
- "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
+ "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
- "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
+ "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
-) (localpart string, err error) {
+) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
- err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
+ err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
- return "", nil
+ return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
- ctx context.Context, localpart string,
+ ctx context.Context,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
- rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
+ rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
- ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
+ ctx context.Context, txn *sql.Tx, threepid, medium,
+ localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
- _, err = stmt.ExecContext(ctx, threepid, medium, localpart)
+ _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err
}
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 354f085f..23aafff0 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
- localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
events := room.Events()
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
- err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
+ err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
assert.NoError(t, err, "unable to save account data")
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
- err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
+ err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
assert.NoError(t, err, "unable to save account data")
- accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
+ accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
assert.NoError(t, err, "unable to get account data by type")
assert.Equal(t, contentRoom, accountData)
- globalData, roomData, err := db.GetAccountData(ctx, localpart)
+ globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
assert.NoError(t, err)
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
@@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
- aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
- accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
+ accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account
- accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
+ accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet)
- accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
+ accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet)
// check account availability
- available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
+ available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available)
- available, err = db.CheckAccountAvailability(ctx, "unusedname")
+ available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart
- first, err := db.GetNewNumericLocalpart(ctx)
+ first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated
- _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
+ _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account")
- second, err := db.GetNewNumericLocalpart(ctx)
+ second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err)
assert.Greater(t, second, first)
// update password for alice
- err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
+ err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to update password")
- accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
+ accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet)
// deactivate account
- err = db.DeactivateAccount(ctx, aliceLocalpart)
+ err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated
- _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
+ _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.Error(t, err, "expected an error, got none")
- _, err = db.GetAccountByLocalpart(ctx, "unusename")
+ _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database
- _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
+ _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart
- _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
+ _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
- _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
+ _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err)
// Now try to create a new guest user
- _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
+ _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err)
})
}
func Test_Devices(t *testing.T) {
alice := test.NewUser(t)
- localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
accessToken := util.RandomString(16)
@@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
- deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
+ deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
- gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
+ gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
@@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
// create a device without existing device ID
accessToken = util.RandomString(16)
- deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
+ deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
- gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
+ gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
// Get devices
- devices, err := db.GetDevicesByLocalpart(ctx, localpart)
+ devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get devices by localpart")
assert.Equal(t, 2, len(devices))
deviceIDs := make([]string, 0, len(devices))
@@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
// Update device
newName := "new display name"
- err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
+ err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
assert.NoError(t, err, "unable to update device displayname")
updatedAfterTimestamp := time.Now().Unix()
- err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
+ err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
assert.NoError(t, err, "unable to update device last seen")
deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1"
- gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
+ gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
@@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
// create one more device and remove the devices step by step
newDeviceID := util.RandomString(16)
accessToken = util.RandomString(16)
- _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
+ _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device")
- devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 3, len(devices))
- err = db.RemoveDevices(ctx, localpart, deviceIDs)
+ err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
assert.NoError(t, err, "unable to remove devices")
- devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 1, len(devices))
- deleted, err := db.RemoveAllDevices(ctx, localpart, "")
+ deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
assert.NoError(t, err, "unable to remove all devices")
assert.Equal(t, 1, len(deleted))
assert.Equal(t, newDeviceID, deleted[0].ID)
@@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
func Test_Profile(t *testing.T) {
alice := test.NewUser(t)
- aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
defer close()
// create account, which also creates a profile
- _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
+ _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
- gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
+ gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get profile by localpart")
- wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
+ wantProfile := &authtypes.Profile{
+ Localpart: aliceLocalpart,
+ ServerName: string(aliceDomain),
+ }
assert.Equal(t, wantProfile, gotProfile)
// set avatar & displayname
wantProfile.DisplayName = "Alice"
- gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
- gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
- gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
@@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
func Test_Pusher(t *testing.T) {
alice := test.NewUser(t)
- aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
ProfileTag: util.RandomString(8),
Language: util.RandomString(2),
}
- err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
+ err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted
- gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
+ gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i])
@@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
}
// remove single pusher
- err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
+ err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to remove pusher")
- gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
+ gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers))
// remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher")
- gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
+ gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers))
})
@@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t)
- aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
- err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
+ err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid
- gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
+ gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart)
+ assert.Equal(t, aliceDomain, gotDomain)
- threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
+ threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{
@@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err, "unexpected error")
// verify it was deleted
- threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
+ threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids))
})
@@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
func Test_Notification(t *testing.T) {
alice := test.NewUser(t)
- aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
+ aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
@@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts),
}
- err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
+ err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification")
}
// get notifications
- count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
+ count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
assert.NoError(t, err, "unable to get notification count")
assert.Equal(t, int64(10), count)
- notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
+ notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
assert.NoError(t, err, "unable to get notifications")
assert.Equal(t, int64(10), count)
assert.Equal(t, 10, len(notifs))
// ... for a specific room
- total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(4), total)
// mark notification as read
- affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
+ affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
// this should delete 2 notifications
- affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
+ affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
- total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(2), total)
@@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
assert.NoError(t, err)
// this should now return 0 notifications
- total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
+ total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(0), total)
})
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 5e1dd097..e14776cf 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -28,31 +28,31 @@ import (
)
type AccountDataTable interface {
- InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
- SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
- SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
+ InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
+ SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
+ SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
}
type AccountsTable interface {
- InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
- UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
- DeactivateAccount(ctx context.Context, localpart string) (err error)
- SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
- SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
- SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
+ InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
+ UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
+ DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
+ SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
+ SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
+ SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
}
type DevicesTable interface {
- InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
- DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
- DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
- DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
- UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
+ InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
+ DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
+ DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
+ DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
+ UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
- SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
- SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
+ SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
+ SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
- UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
+ UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
}
type KeyBackupTable interface {
@@ -79,40 +79,40 @@ type LoginTokenTable interface {
}
type OpenIDTable interface {
- InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
+ InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
type ProfileTable interface {
- InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
- SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
- SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
+ InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
+ SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
+ SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
type ThreePIDTable interface {
- SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
- SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
- InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
+ SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
+ SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
+ InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
type PusherTable interface {
- InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
- SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
- DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
+ InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
+ SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
+ DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
}
type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error
- Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
- DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
- UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
- Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
- SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
- SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
+ Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
+ DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
+ UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
+ Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
+ SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
+ SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
}
type StatsTable interface {
diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go
index a547423b..b088d15c 100644
--- a/userapi/storage/tables/stats_table_test.go
+++ b/userapi/storage/tables/stats_table_test.go
@@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable,
devDB tables.DevicesTable,
localpart string,
+ serverName gomatrixserverlib.ServerName, // nolint:unparam
accType api.AccountType,
userAgent string,
) {
@@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16)
}
- _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
+ _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
if err != nil {
t.Fatalf("unable to create account: %v", err)
}
- _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
+ _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
if err != nil {
t.Fatalf("unable to create device: %v", err)
}
@@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
})
t.Run("Want Users", func(t *testing.T) {
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
- mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
+ mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 2a43c0bd..23acfa6f 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close()
- _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
+ _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
- if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
+ if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
- if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
+ if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}
@@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
- _, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
+ _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
@@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
- _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
+ _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
diff --git a/userapi/util/devices.go b/userapi/util/devices.go
index cbf3bd28..c55fc799 100644
--- a/userapi/util/devices.go
+++ b/userapi/util/devices.go
@@ -2,10 +2,12 @@ package util
import (
"context"
+ "fmt"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@@ -17,10 +19,10 @@ type PusherDevice struct {
}
// GetPushDevices pushes to the configured devices of a local user.
-func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
- pushers, err := db.GetPushers(ctx, localpart)
+func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
+ pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("db.GetPushers: %w", err)
}
devices := make([]*PusherDevice, 0, len(pushers))
diff --git a/userapi/util/notify.go b/userapi/util/notify.go
index ff206bd3..fc0ab39b 100644
--- a/userapi/util/notify.go
+++ b/userapi/util/notify.go
@@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@@ -16,8 +17,8 @@ import (
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
-func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
- pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
+func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
+ pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil {
return err
}
@@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
return nil
}
- userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
+ userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
if err != nil {
return err
}