aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-08-03 18:35:17 +0200
committerGitHub <noreply@github.com>2022-08-03 18:35:17 +0200
commit9fe509b18da997e294813fcc5f46a45b7f6e6784 (patch)
treeeb329616f26d8391a738b3ef2f9ab4994d79fbc7
parent2250768be16bd0e6b3a6a72b5e55eb3e2ad6e3c6 (diff)
Fix syncapi shared users query & device lists (#2614)
* Fix query issue, only add "changed" users if we actually share a room * Avoid log spam if context is done * Undo changes to filterSharedUsers * Add logging again.. * Fix SQLite shared users query * Change query to include invited users
-rw-r--r--keyserver/internal/internal.go11
-rw-r--r--syncapi/internal/keychange.go44
-rw-r--r--syncapi/internal/keychange_test.go1
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go4
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go43
5 files changed, 61 insertions, 42 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index c146b2aa..91f01151 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -18,6 +18,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"sync"
"time"
@@ -314,6 +315,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return
+ }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
@@ -335,6 +341,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetKeyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return
+ }
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index 03df9285..4bf54cae 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -25,10 +25,9 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
)
-const DeviceListLogName = "dl"
-
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse
@@ -93,18 +92,13 @@ func DeviceListCatchup(
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
- var sharedUsersMap map[string]int
- sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
- util.GetLogger(ctx).Debugf(
- "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
- offset, toOffset, queryRes.Offset, queryRes.UserIDs,
- )
+ sharedUsersMap := filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed {
userSet[userID] = true
}
- for _, userID := range queryRes.UserIDs {
- if !userSet[userID] {
+ for userID, count := range sharedUsersMap {
+ if !userSet[userID] && count > 0 {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
@@ -113,7 +107,7 @@ func DeviceListCatchup(
// Finally, add in users who have joined or left.
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
for _, userID := range joinUserIDs {
- if !userSet[userID] {
+ if !userSet[userID] && sharedUsersMap[userID] > 0 {
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
hasNew = true
userSet[userID] = true
@@ -126,6 +120,13 @@ func DeviceListCatchup(
}
}
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "user_id": userID,
+ "from": offset,
+ "to": toOffset,
+ "response_offset": queryRes.Offset,
+ }).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists)
+
return types.StreamPosition(queryRes.Offset), hasNew, nil
}
@@ -220,24 +221,27 @@ func TrackChangedUsers(
// it down to include only users who the requesting user shares a room with.
func filterSharedUsers(
ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
-) (map[string]int, []string) {
+) map[string]int {
sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
- for _, userID := range usersWithChangedKeys {
- sharedUsersMap[userID] = 0
+ for _, changedUserID := range usersWithChangedKeys {
+ sharedUsersMap[changedUserID] = 0
+ if changedUserID == userID {
+ // We forcibly put ourselves in this list because we should be notified about our own device updates
+ // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
+ // be notified about key changes.
+ sharedUsersMap[userID] = 1
+ }
}
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
if err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err)
// default to all users so we do needless queries rather than miss some important device update
- return nil, usersWithChangedKeys
+ return sharedUsersMap
}
for _, userID := range sharedUsers {
sharedUsersMap[userID]++
}
- // We forcibly put ourselves in this list because we should be notified about our own device updates
- // and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
- // be notified about key changes.
- sharedUsersMap[userID] = 1
- return sharedUsersMap, sharedUsers
+ return sharedUsersMap
}
func joinedRooms(res *types.Response, userID string) []string {
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 79ed440e..6bfc91ed 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -129,6 +129,7 @@ type wantCatchup struct {
}
func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) {
+ t.Helper()
if hasNew != want.hasNew {
t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew)
}
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index d13b7be4..58f40451 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -112,7 +112,7 @@ const selectEventsWithEventIDsSQL = "" +
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
- ") AND state_key = ANY($2) AND membership='join';"
+ ") AND state_key = ANY($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
@@ -407,7 +407,7 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
- rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
+ rows, err := stmt.QueryContext(ctx, userID, pq.Array(otherUserIDs))
if err != nil {
return nil, err
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index e19298ae..3a10b232 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -94,9 +94,9 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
const selectSharedUsersSQL = "" +
- "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
- ") AND state_key IN ($2) AND membership='join';"
+ ") AND state_key IN ($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct {
db *sql.DB
@@ -420,25 +420,28 @@ func (s *currentRoomStateStatements) SelectStateEvent(
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
- query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
- stmt, err := s.db.Prepare(query)
- if err != nil {
- return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
- }
- defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
- rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
- if err != nil {
- return nil, err
+
+ params := make([]interface{}, len(otherUserIDs)+1)
+ params[0] = userID
+ for k, v := range otherUserIDs {
+ params[k+1] = v
}
- defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
- var stateKey string
result := make([]string, 0, len(otherUserIDs))
- for rows.Next() {
- if err := rows.Scan(&stateKey); err != nil {
- return nil, err
- }
- result = append(result, stateKey)
- }
- return result, rows.Err()
+ query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
+ err := sqlutil.RunLimitedVariablesQuery(
+ ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
+ func(rows *sql.Rows) error {
+ var stateKey string
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return err
+ }
+ result = append(result, stateKey)
+ }
+ return nil
+ },
+ )
+
+ return result, err
}