diff options
Diffstat (limited to 'roomserver/storage/postgres/state_snapshot_table.go')
-rw-r--r-- | roomserver/storage/postgres/state_snapshot_table.go | 69 |
1 files changed, 64 insertions, 5 deletions
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a00c026f..0e83cfc2 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,10 +21,10 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = ` AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); ` +// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events. +// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON. +const bulkSelectMembershipForHistoryVisibilitySQL = ` +SELECT re.event_id, re2.event_id, rej.event_json +FROM roomserver_events re +LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid +CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid) +LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid +CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid) +JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1 +LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid +WHERE re.event_id = ANY($2) + +` + type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt - bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, + {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL}, }.Prepare(db) } @@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( } return results, rows.Err() } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + var evJson []byte + var eventID string + var membershipEventID string + + knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + + for rows.Next() { + if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil { + return nil, err + } + if len(evJson) == 0 { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + continue + } + // If we already know this event, don't try to marshal the json again + if ev, ok := knownEvents[membershipEventID]; ok { + result[eventID] = ev + continue + } + event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion) + if err != nil { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + // not fatal + continue + } + he := event.Headered(roomInfo.RoomVersion) + result[eventID] = he + knownEvents[membershipEventID] = he + } + return result, rows.Err() +} |