aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-07-23 12:26:31 +0100
committerGitHub <noreply@github.com>2020-07-23 12:26:31 +0100
commit7b862384a779f067f07ffeb2151856f89d372732 (patch)
treee817330aae6a87325c584031cd9a4cc0f9024dea
parentcfeb1b2f4281c9b9b420f9ed9166cd4d8e549288 (diff)
currentstate: Add QuerySharedUsers (#1217)
This will be used to determine who to send device list updates to. It can also be used to determine who to send presence info to.
-rw-r--r--currentstateserver/api/api.go10
-rw-r--r--currentstateserver/currentstateserver_test.go111
-rw-r--r--currentstateserver/internal/api.go13
-rw-r--r--currentstateserver/inthttp/client.go11
-rw-r--r--currentstateserver/inthttp/server.go13
-rw-r--r--currentstateserver/storage/interface.go2
-rw-r--r--currentstateserver/storage/postgres/current_room_state_table.go38
-rw-r--r--currentstateserver/storage/shared/storage.go4
-rw-r--r--currentstateserver/storage/sqlite3/current_room_state_table.go41
-rw-r--r--currentstateserver/storage/tables/interface.go2
10 files changed, 232 insertions, 13 deletions
diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go
index 729a66ba..520ce8d6 100644
--- a/currentstateserver/api/api.go
+++ b/currentstateserver/api/api.go
@@ -31,6 +31,16 @@ type CurrentStateInternalAPI interface {
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
+ // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
+ QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
+}
+
+type QuerySharedUsersRequest struct {
+ UserID string
+}
+
+type QuerySharedUsersResponse struct {
+ UserIDs []string
}
type QueryRoomsForUserRequest struct {
diff --git a/currentstateserver/currentstateserver_test.go b/currentstateserver/currentstateserver_test.go
index a0627fea..4dac742f 100644
--- a/currentstateserver/currentstateserver_test.go
+++ b/currentstateserver/currentstateserver_test.go
@@ -16,9 +16,11 @@ package currentstateserver
import (
"context"
+ "crypto/ed25519"
"encoding/json"
"net/http"
"reflect"
+ "sort"
"testing"
"time"
@@ -178,3 +180,112 @@ func TestQueryCurrentState(t *testing.T) {
runCases(currStateAPI)
})
}
+
+func mustMakeMembershipEvent(t *testing.T, roomID, userID, membership string) *roomserverAPI.OutputNewRoomEvent {
+ eb := gomatrixserverlib.EventBuilder{
+ RoomID: roomID,
+ Sender: userID,
+ StateKey: &userID,
+ Type: "m.room.member",
+ Content: []byte(`{"membership":"` + membership + `"}`),
+ }
+ _, pkey, err := ed25519.GenerateKey(nil)
+ if err != nil {
+ t.Fatalf("failed to make ed25519 key: %s", err)
+ }
+ roomVer := gomatrixserverlib.RoomVersionV5
+ ev, err := eb.Build(
+ time.Now(), gomatrixserverlib.ServerName("localhost"), gomatrixserverlib.KeyID("ed25519:test"),
+ pkey, roomVer,
+ )
+ if err != nil {
+ t.Fatalf("mustMakeMembershipEvent failed: %s", err)
+ }
+
+ return &roomserverAPI.OutputNewRoomEvent{
+ Event: ev.Headered(roomVer),
+ AddsStateEventIDs: []string{ev.EventID()},
+ }
+}
+
+// This test makes sure that QuerySharedUsers is returning the correct users for a range of sets.
+func TestQuerySharedUsers(t *testing.T) {
+ currStateAPI, producer := MustMakeInternalAPI(t)
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@alice:localhost", "join"))
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo:bar", "@bob:localhost", "join"))
+
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@alice:localhost", "join"))
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo2:bar", "@charlie:localhost", "join"))
+
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@alice:localhost", "join"))
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@bob:localhost", "join"))
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo3:bar", "@dave:localhost", "leave"))
+
+ MustWriteOutputEvent(t, producer, mustMakeMembershipEvent(t, "!foo4:bar", "@alice:localhost", "join"))
+
+ testCases := []struct {
+ req api.QuerySharedUsersRequest
+ wantRes api.QuerySharedUsersResponse
+ }{
+ // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C)
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@alice:localhost",
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"},
+ },
+ },
+
+ // Unknown user has no shared users
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@unknownuser:localhost",
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDs: nil,
+ },
+ },
+
+ // left real user produces no shared users
+ {
+ req: api.QuerySharedUsersRequest{
+ UserID: "@dave:localhost",
+ },
+ wantRes: api.QuerySharedUsersResponse{
+ UserIDs: nil,
+ },
+ },
+ }
+
+ runCases := func(testAPI api.CurrentStateInternalAPI) {
+ for _, tc := range testCases {
+ var res api.QuerySharedUsersResponse
+ err := testAPI.QuerySharedUsers(context.Background(), &tc.req, &res)
+ if err != nil {
+ t.Errorf("QuerySharedUsers returned error: %s", err)
+ continue
+ }
+ sort.Strings(res.UserIDs)
+ sort.Strings(tc.wantRes.UserIDs)
+ if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
+ t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
+ }
+ }
+ }
+
+ t.Run("HTTP API", func(t *testing.T) {
+ router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
+ AddInternalRoutes(router, currStateAPI)
+ apiURL, cancel := test.ListenAndServe(t, router, false)
+ defer cancel()
+ httpAPI, err := inthttp.NewCurrentStateAPIClient(apiURL, &http.Client{})
+ if err != nil {
+ t.Fatalf("failed to create HTTP client")
+ }
+ runCases(httpAPI)
+ })
+ t.Run("Monolith", func(t *testing.T) {
+ runCases(currStateAPI)
+ })
+}
diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go
index c2876047..e945d0c1 100644
--- a/currentstateserver/internal/api.go
+++ b/currentstateserver/internal/api.go
@@ -68,3 +68,16 @@ func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req
}
return nil
}
+
+func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
+ roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, "join")
+ if err != nil {
+ return err
+ }
+ users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
+ if err != nil {
+ return err
+ }
+ res.UserIDs = users
+ return nil
+}
diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go
index b8c6a119..cce881ff 100644
--- a/currentstateserver/inthttp/client.go
+++ b/currentstateserver/inthttp/client.go
@@ -29,6 +29,7 @@ const (
QueryCurrentStatePath = "/currentstateserver/queryCurrentState"
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
+ QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
)
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@@ -86,3 +87,13 @@ func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
apiURL := h.apiURL + QueryBulkStateContentPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
+
+func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
+ ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QuerySharedUsersPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go
index dafb9f64..f4e93dcd 100644
--- a/currentstateserver/inthttp/server.go
+++ b/currentstateserver/inthttp/server.go
@@ -64,4 +64,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(QuerySharedUsersPath,
+ httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
+ request := api.QuerySharedUsersRequest{}
+ response := api.QuerySharedUsersResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := intAPI.QuerySharedUsers(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
}
diff --git a/currentstateserver/storage/interface.go b/currentstateserver/storage/interface.go
index 0e95cde8..1c4635be 100644
--- a/currentstateserver/storage/interface.go
+++ b/currentstateserver/storage/interface.go
@@ -37,4 +37,6 @@ type Database interface {
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// Redact a state event
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
+ // JoinedUsersSetInRooms returns all joined users in the rooms given.
+ JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error)
}
diff --git a/currentstateserver/storage/postgres/current_room_state_table.go b/currentstateserver/storage/postgres/current_room_state_table.go
index 79c9f967..9e0070f1 100644
--- a/currentstateserver/storage/postgres/current_room_state_table.go
+++ b/currentstateserver/storage/postgres/current_room_state_table.go
@@ -77,14 +77,18 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
+const selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'"
+
type currentRoomStateStatements struct {
- upsertRoomStateStmt *sql.Stmt
- deleteRoomStateByEventIDStmt *sql.Stmt
- selectRoomIDsWithMembershipStmt *sql.Stmt
- selectEventsWithEventIDsStmt *sql.Stmt
- selectStateEventStmt *sql.Stmt
- selectBulkStateContentStmt *sql.Stmt
- selectBulkStateContentWildStmt *sql.Stmt
+ upsertRoomStateStmt *sql.Stmt
+ deleteRoomStateByEventIDStmt *sql.Stmt
+ selectRoomIDsWithMembershipStmt *sql.Stmt
+ selectEventsWithEventIDsStmt *sql.Stmt
+ selectStateEventStmt *sql.Stmt
+ selectBulkStateContentStmt *sql.Stmt
+ selectBulkStateContentWildStmt *sql.Stmt
+ selectJoinedUsersSetForRoomsStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@@ -114,9 +118,29 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectBulkStateContentWildStmt, err = db.Prepare(selectBulkStateContentWildSQL); err != nil {
return nil, err
}
+ if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
+func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+ rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
+ var userIDs []string
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return userIDs, rows.Err()
+}
+
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context,
diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go
index 66b979d8..aafb5fdd 100644
--- a/currentstateserver/storage/shared/storage.go
+++ b/currentstateserver/storage/shared/storage.go
@@ -85,3 +85,7 @@ func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatr
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
}
+
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+ return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
+}
diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go
index b95fb435..4d3803b6 100644
--- a/currentstateserver/storage/sqlite3/current_room_state_table.go
+++ b/currentstateserver/storage/sqlite3/current_room_state_table.go
@@ -66,13 +66,17 @@ const selectBulkStateContentSQL = "" +
const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
+const selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'"
+
type currentRoomStateStatements struct {
- db *sql.DB
- writer *sqlutil.TransactionWriter
- upsertRoomStateStmt *sql.Stmt
- deleteRoomStateByEventIDStmt *sql.Stmt
- selectRoomIDsWithMembershipStmt *sql.Stmt
- selectStateEventStmt *sql.Stmt
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
+ upsertRoomStateStmt *sql.Stmt
+ deleteRoomStateByEventIDStmt *sql.Stmt
+ selectRoomIDsWithMembershipStmt *sql.Stmt
+ selectStateEventStmt *sql.Stmt
+ selectJoinedUsersSetForRoomsStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@@ -96,9 +100,34 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
return nil, err
}
+ if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
+func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) {
+ iRoomIDs := make([]interface{}, len(roomIDs))
+ for i, v := range roomIDs {
+ iRoomIDs[i] = v
+ }
+ query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
+ rows, err := s.db.QueryContext(ctx, query, iRoomIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
+ var userIDs []string
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return userIDs, rows.Err()
+}
+
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
ctx context.Context,
diff --git a/currentstateserver/storage/tables/interface.go b/currentstateserver/storage/tables/interface.go
index 12884b68..88e7a31b 100644
--- a/currentstateserver/storage/tables/interface.go
+++ b/currentstateserver/storage/tables/interface.go
@@ -36,6 +36,8 @@ type CurrentRoomState interface {
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
+ // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms.
+ SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error)
}
// StrippedEvent represents a stripped event for returning extracted content values.