aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--currentstateserver/api/api.go6
-rw-r--r--currentstateserver/currentstateserver_test.go61
-rw-r--r--currentstateserver/internal/api.go19
-rw-r--r--currentstateserver/storage/interface.go4
-rw-r--r--currentstateserver/storage/postgres/current_room_state_table.go14
-rw-r--r--currentstateserver/storage/shared/storage.go2
-rw-r--r--currentstateserver/storage/sqlite3/current_room_state_table.go13
-rw-r--r--currentstateserver/storage/tables/interface.go5
8 files changed, 95 insertions, 29 deletions
diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go
index 520ce8d6..b778acb2 100644
--- a/currentstateserver/api/api.go
+++ b/currentstateserver/api/api.go
@@ -36,11 +36,13 @@ type CurrentStateInternalAPI interface {
}
type QuerySharedUsersRequest struct {
- UserID string
+ UserID string
+ ExcludeRoomIDs []string
+ IncludeRoomIDs []string
}
type QuerySharedUsersResponse struct {
- UserIDs []string
+ UserIDsToCount map[string]int
}
type QueryRoomsForUserRequest struct {
diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go
index 4dac742f..1366a0be 100644
--- a/currentstateserver/currentstateserver_test.go
+++ b/currentstateserver/currentstateserver_test.go
@@ -20,7 +20,6 @@ import (
"encoding/json"
"net/http"
"reflect"
- "sort"
"testing"
"time"
@@ -227,13 +226,31 @@ func TestQuerySharedUsers(t *testing.T) {
req api.QuerySharedUsersRequest
wantRes api.QuerySharedUsersResponse
}{
- // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
+ // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A:4,B:2,C:1)
{
req: api.QuerySharedUsersRequest{
UserID: "@alice:localhost",
},
wantRes: api.QuerySharedUsersResponse{
- UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
+ UserIDsToCount: map[string]int{
+ "@alice:localhost": 4,
+ "@bob:localhost": 2,
+ "@charlie:localhost": 1,
+ },
+ },
+ },
+
+ // Exclude (A,C): sharing (A,B) (A,B) (A) produces (A:3,B:2)
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@alice:localhost",
+ ExcludeRoomIDs: []string{"!foo2:bar"},
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDsToCount: map[string]int{
+ "@alice:localhost": 3,
+ "@bob:localhost": 2,
+ },
},
},
@@ -243,7 +260,7 @@ func TestQuerySharedUsers(t *testing.T) {
UserID: "@unknownuser:localhost",
},
wantRes: api.QuerySharedUsersResponse{
- UserIDs: nil,
+ UserIDsToCount: map[string]int{},
},
},
@@ -253,7 +270,35 @@ func TestQuerySharedUsers(t *testing.T) {
UserID: "@dave:localhost",
},
wantRes: api.QuerySharedUsersResponse{
- UserIDs: nil,
+ UserIDsToCount: map[string]int{},
+ },
+ },
+
+ // left real user but with included room returns the included room member
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@dave:localhost",
+ IncludeRoomIDs: []string{"!foo:bar"},
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDsToCount: map[string]int{
+ "@alice:localhost": 1,
+ "@bob:localhost": 1,
+ },
+ },
+ },
+
+ // including a room more than once doesn't double counts
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@dave:localhost",
+ IncludeRoomIDs: []string{"!foo:bar", "!foo:bar", "!foo:bar"},
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDsToCount: map[string]int{
+ "@alice:localhost": 1,
+ "@bob:localhost": 1,
+ },
},
},
}
@@ -266,10 +311,8 @@ func TestQuerySharedUsers(t *testing.T) {
t.Errorf("QuerySharedUsers returned error: %s", err)
continue
}
- sort.Strings(res.UserIDs)
- sort.Strings(tc.wantRes.UserIDs)
- if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
- t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
+ if !reflect.DeepEqual(res.UserIDsToCount, tc.wantRes.UserIDsToCount) {
+ t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDsToCount, tc.wantRes.UserIDsToCount)
}
}
}
diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go
index e945d0c1..c581c524 100644
--- a/currentstateserver/internal/api.go
+++ b/currentstateserver/internal/api.go
@@ -74,10 +74,27 @@ func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api
if err != nil {
return err
}
+ roomIDs = append(roomIDs, req.IncludeRoomIDs...)
+ excludeMap := make(map[string]bool)
+ for _, roomID := range req.ExcludeRoomIDs {
+ excludeMap[roomID] = true
+ }
+ // filter out excluded rooms
+ j := 0
+ for i := range roomIDs {
+ // move elements to include to the beginning of the slice
+ // then trim elements on the right
+ if !excludeMap[roomIDs[i]] {
+ roomIDs[j] = roomIDs[i]
+ j++
+ }
+ }
+ roomIDs = roomIDs[:j]
+
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
if err != nil {
return err
}
- res.UserIDs = users
+ res.UserIDsToCount = users
return nil
}
diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go
index 1c4635be..8deaa348 100644
--- a/currentstateserver/storage/interface.go
+++ b/currentstateserver/storage/interface.go
@@ -37,6 +37,6 @@ type Database interface {
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// Redact a state event
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
- // JoinedUsersSetInRooms returns all joined users in the rooms given.
- JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
+ // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
+ JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
}
diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go
index 9e0070f1..294f757c 100644
--- a/currentstateserver/storage/postgres/current_room_state_table.go
+++ b/currentstateserver/storage/postgres/current_room_state_table.go
@@ -78,7 +78,8 @@ const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
- "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
+ "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
+ " type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
@@ -124,21 +125,22 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
return s, nil
}
-func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
- var userIDs []string
+ result := make(map[string]int)
for rows.Next() {
var userID string
- if err := rows.Scan(&userID); err != nil {
+ var count int
+ if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
- userIDs = append(userIDs, userID)
+ result[userID] = count
}
- return userIDs, rows.Err()
+ return result, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go
index aafb5fdd..dac38790 100644
--- a/currentstateserver/storage/shared/storage.go
+++ b/currentstateserver/storage/shared/storage.go
@@ -86,6 +86,6 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
}
-func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
}
diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go
index 4d3803b6..5706fa35 100644
--- a/currentstateserver/storage/sqlite3/current_room_state_table.go
+++ b/currentstateserver/storage/sqlite3/current_room_state_table.go
@@ -67,7 +67,7 @@ const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
const selectJoinedUsersSetForRoomsSQL = "" +
- "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
+ "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
type currentRoomStateStatements struct {
db *sql.DB
@@ -106,7 +106,7 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
return s, nil
}
-func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
@@ -117,15 +117,16 @@ func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Co
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
- var userIDs []string
+ result := make(map[string]int)
for rows.Next() {
var userID string
- if err := rows.Scan(&userID); err != nil {
+ var count int
+ if err := rows.Scan(&userID, &count); err != nil {
return nil, err
}
- userIDs = append(userIDs, userID)
+ result[userID] = count
}
- return userIDs, rows.Err()
+ return result, rows.Err()
}
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go
index 88e7a31b..121bf4fd 100644
--- a/currentstateserver/storage/tables/interface.go
+++ b/currentstateserver/storage/tables/interface.go
@@ -36,8 +36,9 @@ type CurrentRoomState interface {
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
- // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
- SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
+ // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
+ // counts of how many rooms they are joined.
+ SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
}
// StrippedEvent represents a stripped event for returning extracted content values.