aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-04-27 08:07:13 +0200
committerGitHub <noreply@github.com>2023-04-27 08:07:13 +0200
commit2475cf4b61747e76a524af6f71a4eb7e112812af (patch)
treec2446b71a0538fc340a7fb23e8a6c75a48b0a7dd /roomserver/internal/query
parentdd5e47a9a75f717381c27adebdee18aa80a1f256 (diff)
Add some roomserver UTs (#3067)
Adds tests for `QueryRestrictedJoinAllowed`, `IsServerAllowed` and `PerformRoomUpgrade`. Refactors the `QueryRoomVersionForRoom` method to accept a string and return a `gmsl.RoomVersion` instead of req/resp structs. Adds some more caching for `GetStateEvent` This should also fix #2912 by ignoring state events belonging to other users.
Diffstat (limited to 'roomserver/internal/query')
-rw-r--r--roomserver/internal/query/query.go46
1 files changed, 14 insertions, 32 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 8a5a9966..6c515dcc 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -521,14 +521,10 @@ func (r *Queryer) QueryMissingEvents(
response.Events = make([]*gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] {
- roomVersion, verr := r.roomVersion(event.RoomID())
- if verr != nil {
- return verr
- }
if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact()
}
- response.Events = append(response.Events, event.Headered(roomVersion))
+ response.Events = append(response.Events, event.Headered(info.RoomVersion))
}
}
@@ -696,34 +692,20 @@ func GetAuthChain(
}
// QueryRoomVersionForRoom implements api.RoomserverInternalAPI
-func (r *Queryer) QueryRoomVersionForRoom(
- ctx context.Context,
- request *api.QueryRoomVersionForRoomRequest,
- response *api.QueryRoomVersionForRoomResponse,
-) error {
- if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok {
- response.RoomVersion = roomVersion
- return nil
+func (r *Queryer) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
+ if roomVersion, ok := r.Cache.GetRoomVersion(roomID); ok {
+ return roomVersion, nil
}
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
- return err
+ return "", err
}
if info == nil {
- return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID)
+ return "", fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", roomID)
}
- response.RoomVersion = info.RoomVersion
- r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
- return nil
-}
-
-func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) {
- var res api.QueryRoomVersionForRoomResponse
- err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{
- RoomID: roomID,
- }, &res)
- return res.RoomVersion, err
+ r.Cache.StoreRoomVersion(roomID, info.RoomVersion)
+ return info.RoomVersion, nil
}
func (r *Queryer) QueryPublishedRooms(
@@ -910,8 +892,8 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
- // If the join rule isn't "restricted" then there's nothing more to do.
- res.Restricted = joinRules.JoinRule == spec.Restricted
+ // If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do.
+ res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted
if !res.Restricted {
return nil
}
@@ -932,9 +914,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
- var powerLevels gomatrixserverlib.PowerLevelContent
- if err = json.Unmarshal(powerLevelsEvent.Content(), &powerLevels); err != nil {
- return fmt.Errorf("json.Unmarshal: %w", err)
+ powerLevels, err := powerLevelsEvent.PowerLevels()
+ if err != nil {
+ return fmt.Errorf("unable to get powerlevels: %w", err)
}
// Step through the join rules and see if the user matches any of them.
for _, rule := range joinRules.Allow {