diff options
author | devonh <devon.dmytro@gmail.com> | 2023-06-14 14:23:46 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-14 14:23:46 +0000 |
commit | e4665979bfbe006368d55189f074e456fe19b198 (patch) | |
tree | e909d694a022478d0dbe3cc58ee8a2dc289bc969 /roomserver/internal/helpers/helpers.go | |
parent | 7a2e325d1014d76188b47a011730a42443f3c174 (diff) |
Merge SenderID & Per Room User Key work (#3109)
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r-- | roomserver/internal/helpers/helpers.go | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 263cb9f8..febabf41 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -68,7 +68,7 @@ func UpdateToInviteMembership( // memberships. If the servername is not supplied then the local server will be // checked instead using a faster code path. // TODO: This should probably be replaced by an API call. -func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) { +func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) { info, err := db.RoomInfo(ctx, roomID) if err != nil { return false, err @@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam for i := range events { gmslEvents[i] = events[i].PDU } - return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil + return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil } func IsInvitePending( @@ -211,8 +211,8 @@ func GetMembershipsAtState( return events, nil } -func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info, querier) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } @@ -264,7 +264,7 @@ func LoadStateEvents( } func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI, ) (bool, error) { stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) switch err { @@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent( case tables.OptimisationNotSupportedError: // The database engine didn't support this optimisation, so fall back to using // the old and slow method - stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName) + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier) if err != nil { return false, err } @@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent( return false, err } } - return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil + return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil } func slowGetHistoryVisibilityState( - ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, + ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]gomatrixserverlib.PDU, error) { - roomState := state.NewStateResolution(db, info) + roomState := state.NewStateResolution(db, info, querier) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState( // If the event state key doesn't match the given servername // then we'll filter it out. This does preserve state keys that // are "" since these will contain history visibility etc. + validRoomID, err := spec.NewRoomID(roomID) + if err != nil { + return nil, err + } for nid, key := range stateKeys { if key != "" { - userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key)) + userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key)) if err == nil && userID != nil { if userID.Domain() != serverName { delete(stateKeys, nid) @@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState( // TODO: Remove this when we have tests to assert correctness of this function func ScanEventTree( ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName spec.ServerName, + serverName spec.ServerName, querier api.QuerySenderIDAPI, ) ([]types.EventNID, map[string]struct{}, error) { var resultNIDs []types.EventNID var err error @@ -392,7 +396,7 @@ BFSLoop: // It's nasty that we have to extract the room ID from an event, but many federation requests // only talk in event IDs, no room IDs at all (!!!) ev := events[0] - isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID()) + isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID()) if err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") } @@ -415,7 +419,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom) + allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -444,7 +448,7 @@ BFSLoop: } func QueryLatestEventsAndState( - ctx context.Context, db storage.Database, + ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { @@ -457,7 +461,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, roomInfo) + roomState := state.NewStateResolution(db, roomInfo, querier) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion |