diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-08-01 14:11:00 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-01 14:11:00 +0100 |
commit | 05c83923e3bf24fa7def55a14ac096cdff3a3882 (patch) | |
tree | 3ee6bca7393990e75b83971a23611f6982998d59 /roomserver/internal/helpers/helpers.go | |
parent | c7f7aec4d07d59120d37d5b16a900f6d608a75c4 (diff) |
Optimise checking other servers allowed to see events (#2596)
* Try optimising checking if server is allowed to see event
* Fix error
* Handle case where snapshot NID is 0
* Fix query
* Update SQL
* Clean up `CheckServerAllowedToSeeEvent`
* Not supported on SQLite
* Maybe placate the unit tests
* Review comments
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r-- | roomserver/internal/helpers/helpers.go | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 2653027e..16a6f615 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -236,13 +236,34 @@ func LoadStateEvents( func CheckServerAllowedToSeeEvent( ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { + stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName)) + switch err { + case nil: + // No error, so continue normally + case tables.OptimisationNotSupportedError: + // The database engine didn't support this optimisation, so fall back to using + // the old and slow method + stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName) + if err != nil { + return false, err + } + default: + // Something else went wrong + return false, err + } + return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil +} + +func slowGetHistoryVisibilityState( + ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, +) ([]*gomatrixserverlib.Event, error) { roomState := state.NewStateResolution(db, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return false, nil + return nil, nil } - return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) + return nil, fmt.Errorf("roomState.LoadStateAtEvent: %w", err) } // Extract all of the event state key NIDs from the room state. @@ -254,7 +275,7 @@ func CheckServerAllowedToSeeEvent( // Then request those state key NIDs from the database. stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) if err != nil { - return false, fmt.Errorf("db.EventStateKeys: %w", err) + return nil, fmt.Errorf("db.EventStateKeys: %w", err) } // If the event state key doesn't match the given servername @@ -277,15 +298,10 @@ func CheckServerAllowedToSeeEvent( } if len(filteredEntries) == 0 { - return false, nil - } - - stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) - if err != nil { - return false, err + return nil, nil } - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil + return LoadStateEvents(ctx, db, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function |