aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query_room_hierarchy.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query_room_hierarchy.go')
-rw-r--r--roomserver/internal/query/query_room_hierarchy.go77
1 files changed, 51 insertions, 26 deletions
diff --git a/roomserver/internal/query/query_room_hierarchy.go b/roomserver/internal/query/query_room_hierarchy.go
index 5f55980f..3fc61319 100644
--- a/roomserver/internal/query/query_room_hierarchy.go
+++ b/roomserver/internal/query/query_room_hierarchy.go
@@ -39,9 +39,14 @@ import (
//
// If returned walker is nil, then there are no more rooms left to traverse. This method does not modify the provided walker, so it
// can be cached.
-func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker roomserver.RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *roomserver.RoomHierarchyWalker, error) {
- if authorised, _ := authorised(ctx, querier, walker.Caller, walker.RootRoomID, nil); !authorised {
- return nil, nil, roomserver.ErrRoomUnknownOrNotAllowed{Err: fmt.Errorf("room is unknown/forbidden")}
+func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker roomserver.RoomHierarchyWalker, limit int) (
+ []fclient.RoomHierarchyRoom,
+ []string,
+ *roomserver.RoomHierarchyWalker,
+ error,
+) {
+ if authorised, _, _ := authorised(ctx, querier, walker.Caller, walker.RootRoomID, nil); !authorised {
+ return nil, []string{walker.RootRoomID.String()}, nil, roomserver.ErrRoomUnknownOrNotAllowed{Err: fmt.Errorf("room is unknown/forbidden")}
}
discoveredRooms := []fclient.RoomHierarchyRoom{}
@@ -50,6 +55,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
unvisited := make([]roomserver.RoomHierarchyWalkerQueuedRoom, len(walker.Unvisited))
copy(unvisited, walker.Unvisited)
processed := walker.Processed.Copy()
+ inaccessible := []string{}
// Depth first -> stack data structure
for len(unvisited) > 0 {
@@ -108,7 +114,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
// as these children may be rooms we do know about.
roomType = spec.MSpace
}
- } else if authorised, isJoinedOrInvited := authorised(ctx, querier, walker.Caller, queuedRoom.RoomID, queuedRoom.ParentRoomID); authorised {
+ } else if authorised, isJoinedOrInvited, allowedRoomIDs := authorised(ctx, querier, walker.Caller, queuedRoom.RoomID, queuedRoom.ParentRoomID); authorised {
// Get all `m.space.child` state events for this room
events, err := childReferences(ctx, querier, walker.SuggestedOnly, queuedRoom.RoomID)
if err != nil {
@@ -125,14 +131,18 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
}
discoveredRooms = append(discoveredRooms, fclient.RoomHierarchyRoom{
- PublicRoom: *pubRoom,
- RoomType: roomType,
- ChildrenState: events,
+ PublicRoom: *pubRoom,
+ RoomType: roomType,
+ ChildrenState: events,
+ AllowedRoomIDs: allowedRoomIDs,
})
// don't walk children if the user is not joined/invited to the space
if !isJoinedOrInvited {
continue
}
+ } else if !authorised {
+ inaccessible = append(inaccessible, queuedRoom.RoomID.String())
+ continue
} else {
// room exists but user is not authorised
continue
@@ -149,6 +159,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
// We need to invert the order here because the child events are lo->hi on the timestamp,
// so we need to ensure we pop in the same lo->hi order, which won't be the case if we
// insert the highest timestamp last in a stack.
+ extendQueueLoop:
for i := len(discoveredChildEvents) - 1; i >= 0; i-- {
spaceContent := struct {
Via []string `json:"via"`
@@ -161,6 +172,12 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("invalid_room_id", ev.StateKey).WithField("parent_room_id", queuedRoom.RoomID).Warn("Invalid room ID in m.space.child state event")
} else {
+ // Make sure not to queue inaccessible rooms
+ for _, inaccessibleRoomID := range inaccessible {
+ if inaccessibleRoomID == childRoomID.String() {
+ continue extendQueueLoop
+ }
+ }
unvisited = append(unvisited, roomserver.RoomHierarchyWalkerQueuedRoom{
RoomID: *childRoomID,
ParentRoomID: &queuedRoom.RoomID,
@@ -173,7 +190,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
if len(unvisited) == 0 {
// If no more rooms to walk, then don't return a walker for future pages
- return discoveredRooms, nil, nil
+ return discoveredRooms, inaccessible, nil, nil
} else {
// If there are more rooms to walk, then return a new walker to resume walking from (for querying more pages)
newWalker := roomserver.RoomHierarchyWalker{
@@ -185,22 +202,25 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
Processed: processed,
}
- return discoveredRooms, &newWalker, nil
+ return discoveredRooms, inaccessible, &newWalker, nil
}
}
// authorised returns true iff the user is joined this room or the room is world_readable
-func authorised(ctx context.Context, querier *Queryer, caller types.DeviceOrServerName, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed, isJoinedOrInvited bool) {
+func authorised(ctx context.Context, querier *Queryer, caller types.DeviceOrServerName, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed, isJoinedOrInvited bool, resultAllowedRoomIDs []string) {
if clientCaller := caller.Device(); clientCaller != nil {
return authorisedUser(ctx, querier, clientCaller, roomID, parentRoomID)
- } else {
- return authorisedServer(ctx, querier, roomID, *caller.ServerName()), false
}
+ if serverCaller := caller.ServerName(); serverCaller != nil {
+ authed, resultAllowedRoomIDs = authorisedServer(ctx, querier, roomID, *serverCaller)
+ return authed, false, resultAllowedRoomIDs
+ }
+ return false, false, resultAllowedRoomIDs
}
// authorisedServer returns true iff the server is joined this room or the room is world_readable, public, or knockable
-func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID, callerServerName spec.ServerName) bool {
+func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID, callerServerName spec.ServerName) (bool, []string) {
// Check history visibility / join rules first
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
@@ -219,13 +239,13 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
}, &queryRoomRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryCurrentState")
- return false
+ return false, []string{}
}
hisVisEv := queryRoomRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
- return true
+ return true, []string{}
}
}
@@ -238,19 +258,23 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
rule, ruleErr := joinRuleEv.JoinRule()
if ruleErr != nil {
util.GetLogger(ctx).WithError(ruleErr).WithField("parent_room_id", roomID).Warn("failed to get join rule")
- return false
+ return false, []string{}
}
if rule == spec.Public || rule == spec.Knock {
- return true
+ return true, []string{}
}
- if rule == spec.Restricted {
+ if rule == spec.Restricted || rule == spec.KnockRestricted {
allowJoinedToRoomIDs = append(allowJoinedToRoomIDs, restrictedJoinRuleAllowedRooms(ctx, joinRuleEv)...)
}
}
// check if server is joined to any allowed room
+ resultAllowedRoomIDs := make([]string, 0, len(allowJoinedToRoomIDs))
+ for _, allowedRoomID := range allowJoinedToRoomIDs {
+ resultAllowedRoomIDs = append(resultAllowedRoomIDs, allowedRoomID.String())
+ }
for _, allowedRoomID := range allowJoinedToRoomIDs {
var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
err = querier.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, &fs.QueryJoinedHostServerNamesInRoomRequest{
@@ -262,18 +286,18 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
}
for _, srv := range queryRes.ServerNames {
if srv == callerServerName {
- return true
+ return true, resultAllowedRoomIDs[1:]
}
}
}
- return false
+ return false, resultAllowedRoomIDs[1:]
}
// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable
// or if the room has a public or knock join rule.
// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true.
-func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi.Device, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed bool, isJoinedOrInvited bool) {
+func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi.Device, roomID spec.RoomID, parentRoomID *spec.RoomID) (authed bool, isJoinedOrInvited bool, resultAllowedRoomIDs []string) {
hisVisTuple := gomatrixserverlib.StateKeyTuple{
EventType: spec.MRoomHistoryVisibility,
StateKey: "",
@@ -295,20 +319,20 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
}, &queryRes)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("failed to QueryCurrentState")
- return false, false
+ return false, false, resultAllowedRoomIDs
}
memberEv := queryRes.StateEvents[roomMemberTuple]
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join || membership == spec.Invite {
- return true, true
+ return true, true, resultAllowedRoomIDs
}
}
hisVisEv := queryRes.StateEvents[hisVisTuple]
if hisVisEv != nil {
hisVis, _ := hisVisEv.HistoryVisibility()
if hisVis == "world_readable" {
- return true, false
+ return true, false, resultAllowedRoomIDs
}
}
joinRuleEv := queryRes.StateEvents[joinRuleTuple]
@@ -323,6 +347,7 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
allowedRoomIDs := restrictedJoinRuleAllowedRooms(ctx, joinRuleEv)
// check parent is in the allowed set
for _, a := range allowedRoomIDs {
+ resultAllowedRoomIDs = append(resultAllowedRoomIDs, a.String())
if *parentRoomID == a {
allowed = true
break
@@ -345,13 +370,13 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
if memberEv != nil {
membership, _ := memberEv.Membership()
if membership == spec.Join {
- return true, false
+ return true, false, resultAllowedRoomIDs
}
}
}
}
}
- return false, false
+ return false, false, resultAllowedRoomIDs
}
// helper function to fetch a state event