aboutsummaryrefslogtreecommitdiff
path: root/clientapi
diff options
context:
space:
mode:
Diffstat (limited to 'clientapi')
-rw-r--r--clientapi/clientapi_test.go127
-rw-r--r--clientapi/routing/memberships.go139
-rw-r--r--clientapi/routing/routing.go10
3 files changed, 276 insertions, 0 deletions
diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go
index 2ff4b650..1b2f1358 100644
--- a/clientapi/clientapi_test.go
+++ b/clientapi/clientapi_test.go
@@ -2151,3 +2151,130 @@ func TestKeyBackup(t *testing.T) {
}
})
}
+
+func TestGetMembership(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ testCases := []struct {
+ name string
+ roomID string
+ user *test.User
+ additionalEvents func(t *testing.T, room *test.Room)
+ request func(t *testing.T, room *test.Room, accessToken string) *http.Request
+ wantOK bool
+ wantMemberCount int
+ }{
+
+ {
+ name: "/joined_members - Bob never joined",
+ user: bob,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ wantOK: false,
+ },
+ {
+ name: "/joined_members - Alice joined",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ wantOK: true,
+ wantMemberCount: 1,
+ },
+ {
+ name: "/joined_members - Alice leaves, shouldn't be able to see members ",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
+ "membership": "leave",
+ }, test.WithStateKey(alice.ID))
+ },
+ wantOK: false,
+ },
+ {
+ name: "/joined_members - Bob joins, Alice sees two members",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ },
+ wantOK: true,
+ wantMemberCount: 2,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
+ defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
+
+ // Use an actual roomserver for this
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ rsAPI.SetFederationAPI(nil, nil)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
+
+ // We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+
+ accessTokens := map[*test.User]userDevice{
+ alice: {},
+ bob: {},
+ }
+ createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ room := test.NewRoom(t, alice)
+ t.Cleanup(func() {
+ t.Logf("running cleanup for %s", tc.name)
+ })
+ // inject additional events
+ if tc.additionalEvents != nil {
+ tc.additionalEvents(t, room)
+ }
+ if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ routers.Client.ServeHTTP(w, tc.request(t, room, accessTokens[tc.user].accessToken))
+ if w.Code != 200 && tc.wantOK {
+ t.Logf("%s", w.Body.String())
+ t.Fatalf("got HTTP %d want %d", w.Code, 200)
+ }
+ t.Logf("[%s] Resp: %s", tc.name, w.Body.String())
+
+ // check we got the expected events
+ if tc.wantOK {
+ memberCount := len(gjson.GetBytes(w.Body.Bytes(), "joined").Map())
+ if memberCount != tc.wantMemberCount {
+ t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount)
+ }
+ }
+ })
+ }
+ })
+}
diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go
new file mode 100644
index 00000000..84be498d
--- /dev/null
+++ b/clientapi/routing/memberships.go
@@ -0,0 +1,139 @@
+// Copyright 2024 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/util"
+)
+
+// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members
+type getJoinedMembersResponse struct {
+ Joined map[string]joinedMember `json:"joined"`
+}
+
+type joinedMember struct {
+ DisplayName string `json:"display_name"`
+ AvatarURL string `json:"avatar_url"`
+}
+
+// The database stores 'displayname' without an underscore.
+// Deserialize into this and then change to the actual API response
+type databaseJoinedMember struct {
+ DisplayName string `json:"displayname"`
+ AvatarURL string `json:"avatar_url"`
+}
+
+// GetJoinedMembers implements
+//
+// GET /rooms/{roomId}/joined_members
+func GetJoinedMembers(
+ req *http.Request, device *userapi.Device, roomID string,
+ rsAPI api.ClientRoomserverAPI,
+) util.JSONResponse {
+ // Validate the userID
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
+
+ // Validate the roomID
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("RoomID is invalid"),
+ }
+ }
+
+ // Get the current memberships for the requesting user to determine
+ // if they are allowed to query this endpoint.
+ queryReq := api.QueryMembershipForUserRequest{
+ RoomID: validRoomID.String(),
+ UserID: *userID,
+ }
+
+ var queryRes api.QueryMembershipForUserResponse
+ if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
+ util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ if !queryRes.HasBeenInRoom {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
+ }
+ }
+
+ if !queryRes.IsInRoom {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
+ }
+ }
+
+ // Get the current membership events
+ var membershipsForRoomResp api.QueryMembershipsForRoomResponse
+ if err = rsAPI.QueryMembershipsForRoom(req.Context(), &api.QueryMembershipsForRoomRequest{
+ JoinedOnly: true,
+ RoomID: validRoomID.String(),
+ }, &membershipsForRoomResp); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ var res getJoinedMembersResponse
+ res.Joined = make(map[string]joinedMember)
+ for _, ev := range membershipsForRoomResp.JoinEvents {
+ var content databaseJoinedMember
+ if err := json.Unmarshal(ev.Content, &content); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(ev.Sender))
+ if err != nil || userID == nil {
+ util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ res.Joined[userID.String()] = joinedMember(content)
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: res,
+ }
+}
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index d4aa1d08..3e23ab40 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -1513,4 +1513,14 @@ func Setup(
return GetPresence(req, device, natsClient, cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), vars["userId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/rooms/{roomID}/joined_members",
+ httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return GetJoinedMembers(req, device, vars["roomID"], rsAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
}