aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clientapi/clientapi_test.go127
-rw-r--r--clientapi/routing/memberships.go139
-rw-r--r--clientapi/routing/routing.go10
-rw-r--r--syncapi/routing/memberships.go62
-rw-r--r--syncapi/routing/routing.go14
-rw-r--r--syncapi/syncapi_test.go33
6 files changed, 278 insertions, 107 deletions
diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go
index 2ff4b650..1b2f1358 100644
--- a/clientapi/clientapi_test.go
+++ b/clientapi/clientapi_test.go
@@ -2151,3 +2151,130 @@ func TestKeyBackup(t *testing.T) {
}
})
}
+
+func TestGetMembership(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ testCases := []struct {
+ name string
+ roomID string
+ user *test.User
+ additionalEvents func(t *testing.T, room *test.Room)
+ request func(t *testing.T, room *test.Room, accessToken string) *http.Request
+ wantOK bool
+ wantMemberCount int
+ }{
+
+ {
+ name: "/joined_members - Bob never joined",
+ user: bob,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ wantOK: false,
+ },
+ {
+ name: "/joined_members - Alice joined",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ wantOK: true,
+ wantMemberCount: 1,
+ },
+ {
+ name: "/joined_members - Alice leaves, shouldn't be able to see members ",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
+ "membership": "leave",
+ }, test.WithStateKey(alice.ID))
+ },
+ wantOK: false,
+ },
+ {
+ name: "/joined_members - Bob joins, Alice sees two members",
+ user: alice,
+ request: func(t *testing.T, room *test.Room, accessToken string) *http.Request {
+ return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
+ "access_token": accessToken,
+ }))
+ },
+ additionalEvents: func(t *testing.T, room *test.Room) {
+ room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, test.WithStateKey(bob.ID))
+ },
+ wantOK: true,
+ wantMemberCount: 2,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ jsctx, _ := natsInstance.Prepare(processCtx, &cfg.Global.JetStream)
+ defer jetstream.DeleteAllStreams(jsctx, &cfg.Global.JetStream)
+
+ // Use an actual roomserver for this
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ rsAPI.SetFederationAPI(nil, nil)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff)
+
+ // We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+
+ accessTokens := map[*test.User]userDevice{
+ alice: {},
+ bob: {},
+ }
+ createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ room := test.NewRoom(t, alice)
+ t.Cleanup(func() {
+ t.Logf("running cleanup for %s", tc.name)
+ })
+ // inject additional events
+ if tc.additionalEvents != nil {
+ tc.additionalEvents(t, room)
+ }
+ if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Fatalf("failed to send events: %v", err)
+ }
+
+ w := httptest.NewRecorder()
+ routers.Client.ServeHTTP(w, tc.request(t, room, accessTokens[tc.user].accessToken))
+ if w.Code != 200 && tc.wantOK {
+ t.Logf("%s", w.Body.String())
+ t.Fatalf("got HTTP %d want %d", w.Code, 200)
+ }
+ t.Logf("[%s] Resp: %s", tc.name, w.Body.String())
+
+ // check we got the expected events
+ if tc.wantOK {
+ memberCount := len(gjson.GetBytes(w.Body.Bytes(), "joined").Map())
+ if memberCount != tc.wantMemberCount {
+ t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount)
+ }
+ }
+ })
+ }
+ })
+}
diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go
new file mode 100644
index 00000000..84be498d
--- /dev/null
+++ b/clientapi/routing/memberships.go
@@ -0,0 +1,139 @@
+// Copyright 2024 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/util"
+)
+
+// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members
+type getJoinedMembersResponse struct {
+ Joined map[string]joinedMember `json:"joined"`
+}
+
+type joinedMember struct {
+ DisplayName string `json:"display_name"`
+ AvatarURL string `json:"avatar_url"`
+}
+
+// The database stores 'displayname' without an underscore.
+// Deserialize into this and then change to the actual API response
+type databaseJoinedMember struct {
+ DisplayName string `json:"displayname"`
+ AvatarURL string `json:"avatar_url"`
+}
+
+// GetJoinedMembers implements
+//
+// GET /rooms/{roomId}/joined_members
+func GetJoinedMembers(
+ req *http.Request, device *userapi.Device, roomID string,
+ rsAPI api.ClientRoomserverAPI,
+) util.JSONResponse {
+ // Validate the userID
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
+
+ // Validate the roomID
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("RoomID is invalid"),
+ }
+ }
+
+ // Get the current memberships for the requesting user to determine
+ // if they are allowed to query this endpoint.
+ queryReq := api.QueryMembershipForUserRequest{
+ RoomID: validRoomID.String(),
+ UserID: *userID,
+ }
+
+ var queryRes api.QueryMembershipForUserResponse
+ if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
+ util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ if !queryRes.HasBeenInRoom {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
+ }
+ }
+
+ if !queryRes.IsInRoom {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
+ }
+ }
+
+ // Get the current membership events
+ var membershipsForRoomResp api.QueryMembershipsForRoomResponse
+ if err = rsAPI.QueryMembershipsForRoom(req.Context(), &api.QueryMembershipsForRoomRequest{
+ JoinedOnly: true,
+ RoomID: validRoomID.String(),
+ }, &membershipsForRoomResp); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryEventsByID failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ var res getJoinedMembersResponse
+ res.Joined = make(map[string]joinedMember)
+ for _, ev := range membershipsForRoomResp.JoinEvents {
+ var content databaseJoinedMember
+ if err := json.Unmarshal(ev.Content, &content); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(ev.Sender))
+ if err != nil || userID == nil {
+ util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
+ }
+
+ res.Joined[userID.String()] = joinedMember(content)
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: res,
+ }
+}
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index d4aa1d08..3e23ab40 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -1513,4 +1513,14 @@ func Setup(
return GetPresence(req, device, natsClient, cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), vars["userId"])
}),
).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/rooms/{roomID}/joined_members",
+ httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return GetJoinedMembers(req, device, vars["roomID"], rsAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
}
diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go
index e849adf6..9cc937d8 100644
--- a/syncapi/routing/memberships.go
+++ b/syncapi/routing/memberships.go
@@ -15,7 +15,6 @@
package routing
import (
- "encoding/json"
"math"
"net/http"
@@ -33,31 +32,13 @@ type getMembershipResponse struct {
Chunk []synctypes.ClientEvent `json:"chunk"`
}
-// https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-rooms-roomid-joined-members
-type getJoinedMembersResponse struct {
- Joined map[string]joinedMember `json:"joined"`
-}
-
-type joinedMember struct {
- DisplayName string `json:"display_name"`
- AvatarURL string `json:"avatar_url"`
-}
-
-// The database stores 'displayname' without an underscore.
-// Deserialize into this and then change to the actual API response
-type databaseJoinedMember struct {
- DisplayName string `json:"displayname"`
- AvatarURL string `json:"avatar_url"`
-}
-
// GetMemberships implements
//
// GET /rooms/{roomId}/members
-// GET /rooms/{roomId}/joined_members
func GetMemberships(
req *http.Request, device *userapi.Device, roomID string,
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
- joinedOnly bool, membership, notMembership *string, at string,
+ membership, notMembership *string, at string,
) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
@@ -87,13 +68,6 @@ func GetMemberships(
}
}
- if joinedOnly && !queryRes.IsInRoom {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("You aren't a member of the room and weren't previously a member of the room."),
- }
- }
-
db, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
return util.JSONResponse{
@@ -139,40 +113,6 @@ func GetMemberships(
result := qryRes.Events
- if joinedOnly {
- var res getJoinedMembersResponse
- res.Joined = make(map[string]joinedMember)
- for _, ev := range result {
- var content databaseJoinedMember
- if err := json.Unmarshal(ev.Content(), &content); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("failed to unmarshal event content")
- return util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.InternalServerError{},
- }
- }
-
- userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
- if err != nil || userID == nil {
- util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
- return util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.InternalServerError{},
- }
- }
- if err != nil {
- return util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"),
- }
- }
- res.Joined[userID.String()] = joinedMember(content)
- }
- return util.JSONResponse{
- Code: http.StatusOK,
- JSON: res,
- }
- }
return util.JSONResponse{
Code: http.StatusOK,
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index a837e169..78188d1b 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -197,19 +197,7 @@ func Setup(
}
at := req.URL.Query().Get("at")
- return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at)
+ return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, membership, notMembership, at)
}, httputil.WithAllowGuests()),
).Methods(http.MethodGet, http.MethodOptions)
-
- v3mux.Handle("/rooms/{roomID}/joined_members",
- httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
- vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
- if err != nil {
- return util.ErrorResponse(err)
- }
- at := req.URL.Query().Get("at")
- membership := spec.Join
- return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, true, &membership, nil, at)
- }),
- ).Methods(http.MethodGet, http.MethodOptions)
}
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index ac526851..0a2c38ab 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -754,24 +754,6 @@ func TestGetMembership(t *testing.T) {
wantOK: false,
},
{
- name: "/joined_members - Bob never joined",
- request: func(t *testing.T, room *test.Room) *http.Request {
- return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
- "access_token": bobDev.AccessToken,
- }))
- },
- wantOK: false,
- },
- {
- name: "/joined_members - Alice joined",
- request: func(t *testing.T, room *test.Room) *http.Request {
- return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
- "access_token": aliceDev.AccessToken,
- }))
- },
- wantOK: true,
- },
- {
name: "Alice leaves before Bob joins, should not be able to see Bob",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
@@ -810,21 +792,6 @@ func TestGetMembership(t *testing.T) {
wantMemberCount: 2,
},
{
- name: "/joined_members - Alice leaves, shouldn't be able to see members ",
- request: func(t *testing.T, room *test.Room) *http.Request {
- return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
- "access_token": aliceDev.AccessToken,
- }))
- },
- additionalEvents: func(t *testing.T, room *test.Room) {
- room.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
- "membership": "leave",
- }, test.WithStateKey(alice.ID))
- },
- useSleep: true,
- wantOK: false,
- },
- {
name: "'at' specified, returns memberships before Bob joins",
request: func(t *testing.T, room *test.Room) *http.Request {
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{