aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/helpers/helpers.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r--roomserver/internal/helpers/helpers.go42
1 files changed, 38 insertions, 4 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index e4de878e..036c717a 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
+ "strings"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
@@ -222,12 +223,45 @@ func CheckServerAllowedToSeeEvent(
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
- return false, err
+ return false, fmt.Errorf("roomState.LoadStateAtEvent: %w", err)
+ }
+
+ // Extract all of the event state key NIDs from the room state.
+ var stateKeyNIDs []types.EventStateKeyNID
+ for _, entry := range stateEntries {
+ stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
+ }
+
+ // Then request those state key NIDs from the database.
+ stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
+ if err != nil {
+ return false, fmt.Errorf("db.EventStateKeys: %w", err)
+ }
+
+ // If the event state key doesn't match the given servername
+ // then we'll filter it out. This does preserve state keys that
+ // are "" since these will contain history visibility etc.
+ for nid, key := range stateKeys {
+ if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
+ delete(stateKeys, nid)
+ }
+ }
+
+ // Now filter through all of the state events for the room.
+ // If the state key NID appears in the list of valid state
+ // keys then we'll add it to the list of filtered entries.
+ var filteredEntries []types.StateEntry
+ for _, entry := range stateEntries {
+ if _, ok := stateKeys[entry.EventStateKeyNID]; ok {
+ filteredEntries = append(filteredEntries, entry)
+ }
+ }
+
+ if len(filteredEntries) == 0 {
+ return false, nil
}
- // TODO: We probably want to make it so that we don't have to pull
- // out all the state if possible.
- stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
+ stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
if err != nil {
return false, err
}