aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--roomserver/api/query.go1
-rw-r--r--roomserver/internal/query/query.go2
-rw-r--r--roomserver/storage/interface.go2
-rw-r--r--roomserver/storage/postgres/membership_table.go17
-rw-r--r--roomserver/storage/shared/storage.go4
-rw-r--r--roomserver/storage/sqlite3/membership_table.go23
-rw-r--r--roomserver/storage/tables/interface.go2
-rw-r--r--roomserver/storage/tables/membership_table_test.go2
-rw-r--r--syncapi/consumers/keychange.go6
9 files changed, 36 insertions, 23 deletions
diff --git a/roomserver/api/query.go b/roomserver/api/query.go
index aa7dc473..d63c2478 100644
--- a/roomserver/api/query.go
+++ b/roomserver/api/query.go
@@ -278,6 +278,7 @@ type QuerySharedUsersRequest struct {
OtherUserIDs []string
ExcludeRoomIDs []string
IncludeRoomIDs []string
+ LocalOnly bool
}
type QuerySharedUsersResponse struct {
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index b41a92e9..ee8e1cfe 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -799,7 +799,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
}
roomIDs = roomIDs[:j]
- users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
+ users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs, req.LocalOnly)
if err != nil {
return err
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 43e8da7b..11e175f5 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -157,7 +157,7 @@ type Database interface {
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
- JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
+ JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index bd3fd559..0150534e 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -68,14 +68,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
- " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
- " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " WHERE (target_local OR $1 = false)" +
+ " AND room_nid = ANY($2) AND target_nid = ANY($3)" +
+ " AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
+ " AND forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
- " WHERE room_nid = ANY($1) AND" +
- " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " WHERE (target_local OR $1 = false) " +
+ " AND room_nid = ANY($2)" +
+ " AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
+ " AND forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
@@ -334,6 +338,7 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
+ localOnly bool,
) (map[types.EventStateKeyNID]int, error) {
var (
rows *sql.Rows
@@ -342,9 +347,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
if len(userNIDs) > 0 {
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
- rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
+ rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs), pq.Array(userNIDs))
} else {
- rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
+ rows, err = stmt.QueryContext(ctx, localOnly, pq.Array(roomNIDs))
}
if err != nil {
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 593abbea..d83a1ff7 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1280,7 +1280,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
-func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string, localOnly bool) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
@@ -1295,7 +1295,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
- userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
+ userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs, localOnly)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index f3303eb0..cd149f0e 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -44,14 +44,18 @@ const membershipSchema = `
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
- " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
- " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " WHERE (target_local OR $1 = false)" +
+ " AND room_nid IN ($2) AND target_nid IN ($3)" +
+ " AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
+ " AND forgotten = false" +
" GROUP BY target_nid"
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
- " WHERE room_nid IN ($1) AND " +
- " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " WHERE (target_local OR $1 = false)" +
+ " AND room_nid IN ($2)" +
+ " AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
+ " AND forgotten = false" +
" GROUP BY target_nid"
// Insert a row in to membership table so that it can be locked by the
@@ -305,8 +309,9 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
-func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
- params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
+func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error) {
+ params := make([]interface{}, 0, 1+len(roomNIDs)+len(userNIDs))
+ params = append(params, localOnly)
for _, v := range roomNIDs {
params = append(params, v)
}
@@ -314,10 +319,10 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
params = append(params, v)
}
- query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
+ query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
if len(userNIDs) > 0 {
- query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
- query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
+ query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(roomNIDs), 1), 1)
+ query = strings.Replace(query, "($3)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)+1), 1)
}
var rows *sql.Rows
var err error
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 68d30f99..d7bcc95a 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -137,7 +137,7 @@ type Membership interface {
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
- SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
+ SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID, localOnly bool) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go
index f789ef4a..c9541d9d 100644
--- a/roomserver/storage/tables/membership_table_test.go
+++ b/roomserver/storage/tables/membership_table_test.go
@@ -79,7 +79,7 @@ func TestMembershipTable(t *testing.T) {
assert.NoError(t, err)
assert.True(t, inRoom)
- userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs)
+ userJoinedToRooms, err := tab.SelectJoinedUsersSetForRooms(ctx, nil, []types.RoomNID{1}, userNIDs, false)
assert.NoError(t, err)
assert.Equal(t, 1, len(userJoinedToRooms))
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index dc7d9e20..96ebba7e 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -111,7 +111,8 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, d
// work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
- UserID: output.UserID,
+ UserID: output.UserID,
+ LocalOnly: true,
}, &queryRes)
if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")
@@ -135,7 +136,8 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
// work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
- UserID: output.UserID,
+ UserID: output.UserID,
+ LocalOnly: true,
}, &queryRes)
if err != nil {
logrus.WithError(err).Error("syncapi: failed to QuerySharedUsers for key change event from key server")