aboutsummaryrefslogtreecommitdiff
path: root/federationapi/federationapi_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'federationapi/federationapi_test.go')
-rw-r--r--federationapi/federationapi_test.go22
1 files changed, 13 insertions, 9 deletions
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index c426eb67..4c2a99bb 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -33,7 +33,7 @@ import (
type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
- queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
+ queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
}
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
@@ -54,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
}
// keychange consumer calls this
-func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
+func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if f.queryRoomsForUser == nil {
- return nil
+ return nil, nil
}
- return f.queryRoomsForUser(ctx, req, res)
+ return f.queryRoomsForUser(ctx, userID, desiredMembership)
}
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
@@ -199,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator)
+ roomID, err := spec.NewRoomID(room.ID)
+ if err != nil {
+ t.Fatalf("Invalid room ID: %q", roomID)
+ }
+
rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
}
},
- queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error {
- if req.UserID == joiningUser.ID && req.WantMembership == "join" {
- res.RoomIDs = []string{room.ID}
- return nil
+ queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
+ if userID.String() == joiningUser.ID && desiredMembership == "join" {
+ return []spec.RoomID{*roomID}, nil
}
- return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req)
+ return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership)
},
}
fc := &fedClient{