aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/internal/keychange.go33
-rw-r--r--syncapi/internal/keychange_test.go31
-rw-r--r--syncapi/storage/interface.go7
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go30
-rw-r--r--syncapi/storage/shared/syncserver.go4
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go34
-rw-r--r--syncapi/storage/tables/interface.go2
-rw-r--r--syncapi/streams/stream_devicelist.go2
-rw-r--r--syncapi/sync/requestpool.go2
9 files changed, 118 insertions, 27 deletions
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index d96718d2..03df9285 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -21,6 +21,7 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -46,7 +47,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
- ctx context.Context, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
+ ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) {
@@ -93,7 +94,7 @@ func DeviceListCatchup(
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
var sharedUsersMap map[string]int
- sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
+ sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf(
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
offset, toOffset, queryRes.Offset, queryRes.UserIDs,
@@ -215,30 +216,28 @@ func TrackChangedUsers(
return changed, left, nil
}
+// filterSharedUsers takes a list of remote users whose keys have changed and filters
+// it down to include only users who the requesting user shares a room with.
func filterSharedUsers(
- ctx context.Context, rsAPI roomserverAPI.SyncRoomserverAPI, userID string, usersWithChangedKeys []string,
+ ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
) (map[string]int, []string) {
- var result []string
- var sharedUsersRes roomserverAPI.QuerySharedUsersResponse
- err := rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{
- UserID: userID,
- OtherUserIDs: usersWithChangedKeys,
- }, &sharedUsersRes)
+ sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
+ for _, userID := range usersWithChangedKeys {
+ sharedUsersMap[userID] = 0
+ }
+ sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
if err != nil {
// default to all users so we do needless queries rather than miss some important device update
return nil, usersWithChangedKeys
}
+ for _, userID := range sharedUsers {
+ sharedUsersMap[userID]++
+ }
// We forcibly put ourselves in this list because we should be notified about our own device updates
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
// be notified about key changes.
- sharedUsersRes.UserIDsToCount[userID] = 1
-
- for _, uid := range usersWithChangedKeys {
- if sharedUsersRes.UserIDsToCount[uid] > 0 {
- result = append(result, uid)
- }
- }
- return sharedUsersRes.UserIDsToCount, result
+ sharedUsersMap[userID] = 1
+ return sharedUsersMap, sharedUsers
}
func joinedRooms(res *types.Response, userID string) []string {
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 219b35e2..79ed440e 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
)
var (
@@ -105,6 +106,22 @@ func (s *mockRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.Query
return nil
}
+// This is actually a database function, but seeing as we track the state inside the
+// *mockRoomserverAPI, we'll just comply with the interface here instead.
+func (s *mockRoomserverAPI) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
+ commonUsers := []string{}
+ for _, members := range s.roomIDToJoinedMembers {
+ for _, member := range members {
+ for _, userID := range otherUserIDs {
+ if member == userID {
+ commonUsers = append(commonUsers, userID)
+ }
+ }
+ }
+ }
+ return util.UniqueStrings(commonUsers), nil
+}
+
type wantCatchup struct {
hasNew bool
changed []string
@@ -178,7 +195,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -201,7 +218,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -224,7 +241,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@@ -246,7 +263,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -305,7 +322,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
}
@@ -333,7 +350,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser},
},
}
- _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
+ _, hasNew, err := DeviceListCatchup(context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil {
t.Fatalf("Catchup returned an error: %s", err)
}
@@ -419,7 +436,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
},
}
_, hasNew, err := DeviceListCatchup(
- context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
+ context.Background(), rsAPI, &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
)
if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err)
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 5a036d88..05542603 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -27,6 +27,8 @@ import (
type Database interface {
Presence
+ SharedUsers
+
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
@@ -165,3 +167,8 @@ type Presence interface {
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
}
+
+type SharedUsers interface {
+ // SharedUsers returns a subset of otherUserIDs that share a room with userID.
+ SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error)
+}
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 8ee387b3..c4667baf 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -107,6 +107,11 @@ const selectEventsWithEventIDsSQL = "" +
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
+const selectSharedUsersSQL = "" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
+ " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
+ ") AND state_key = ANY($2) AND membership='join';"
+
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -118,6 +123,7 @@ type currentRoomStateStatements struct {
selectJoinedUsersInRoomStmt *sql.Stmt
selectEventsWithEventIDsStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
+ selectSharedUsersStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@@ -156,6 +162,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
+ if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -379,3 +388,24 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
+
+func (s *currentRoomStateStatements) SelectSharedUsers(
+ ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
+) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
+ rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
+
+ var stateKey string
+ result := make([]string, 0, len(otherUserIDs))
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return nil, err
+ }
+ result = append(result, stateKey)
+ }
+ return result, rows.Err()
+}
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 76114aff..d1c5e2d1 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -176,6 +176,10 @@ func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]t
return d.Peeks.SelectPeekingDevices(ctx)
}
+func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
+ return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
+}
+
func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index f0a1c7bb..376c3a3d 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -91,6 +91,11 @@ const selectEventsWithEventIDsSQL = "" +
"SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
+const selectSharedUsersSQL = "" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
+ " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
+ ") AND state_key IN ($2) AND membership='join';"
+
type currentRoomStateStatements struct {
db *sql.DB
streamIDStatements *StreamIDStatements
@@ -100,8 +105,9 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt
selectRoomIDsWithAnyMembershipStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
- //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
+ //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic
selectStateEventStmt *sql.Stmt
+ //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic
}
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
@@ -396,3 +402,29 @@ func (s *currentRoomStateStatements) SelectStateEvent(
}
return &ev, err
}
+
+func (s *currentRoomStateStatements) SelectSharedUsers(
+ ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
+) ([]string, error) {
+ query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
+ stmt, err := s.db.Prepare(query)
+ if err != nil {
+ return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
+ rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
+
+ var stateKey string
+ result := make([]string, 0, len(otherUserIDs))
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return nil, err
+ }
+ result = append(result, stateKey)
+ }
+ return result, rows.Err()
+}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index ccdebfdb..08568d9a 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -104,6 +104,8 @@ type CurrentRoomState interface {
SelectJoinedUsers(ctx context.Context) (map[string][]string, error)
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
SelectJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
+ // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID.
+ SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error)
}
// BackwardsExtremities keeps track of backwards extremities for a room.
diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go
index f4209951..5448ee5b 100644
--- a/syncapi/streams/stream_devicelist.go
+++ b/syncapi/streams/stream_devicelist.go
@@ -28,7 +28,7 @@ func (p *DeviceListStreamProvider) IncrementalSync(
from, to types.StreamPosition,
) types.StreamPosition {
var err error
- to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
+ to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index 6f0849e0..b6b4779a 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -429,7 +429,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
}
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
- req.Context(), rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
+ req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {