diff options
Diffstat (limited to 'federationapi/federationapi_test.go')
-rw-r--r-- | federationapi/federationapi_test.go | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index c426eb67..4c2a99bb 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -33,7 +33,7 @@ import ( type fedRoomserverAPI struct { rsapi.FederationRoomserverAPI inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) - queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error + queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) } func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { @@ -54,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input } // keychange consumer calls this -func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { +func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) { if f.queryRoomsForUser == nil { - return nil + return nil, nil } - return f.queryRoomsForUser(ctx, req, res) + return f.queryRoomsForUser(ctx, userID, desiredMembership) } // TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate @@ -199,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) { fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) room := test.NewRoom(t, creator) + roomID, err := spec.NewRoomID(room.ID) + if err != nil { + t.Fatalf("Invalid room ID: %q", roomID) + } + rsapi := &fedRoomserverAPI{ inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if req.Asynchronous { t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") } }, - queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { - if req.UserID == joiningUser.ID && req.WantMembership == "join" { - res.RoomIDs = []string{room.ID} - return nil + queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) { + if userID.String() == joiningUser.ID && desiredMembership == "join" { + return []spec.RoomID{*roomID}, nil } - return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) + return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership) }, } fc := &fedClient{ |