aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/state_snapshot_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres/state_snapshot_table.go')
-rw-r--r--roomserver/storage/postgres/state_snapshot_table.go69
1 files changed, 64 insertions, 5 deletions
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index a00c026f..0e83cfc2 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -21,10 +21,10 @@ import (
"fmt"
"github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = `
AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2);
`
+// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events.
+// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON.
+const bulkSelectMembershipForHistoryVisibilitySQL = `
+SELECT re.event_id, re2.event_id, rej.event_json
+FROM roomserver_events re
+LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid
+CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid)
+LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid
+CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid)
+JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1
+LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid
+WHERE re.event_id = ANY($2)
+
+`
+
type stateSnapshotStatements struct {
- insertStateStmt *sql.Stmt
- bulkSelectStateBlockNIDsStmt *sql.Stmt
- bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
+ insertStateStmt *sql.Stmt
+ bulkSelectStateBlockNIDsStmt *sql.Stmt
+ bulkSelectStateForHistoryVisibilityStmt *sql.Stmt
+ bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt
}
func CreateStateSnapshotTable(db *sql.DB) error {
@@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error {
return err
}
-func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
{&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL},
+ {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL},
}.Prepare(db)
}
@@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility(
}
return results, rows.Err()
}
+
+func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(
+ ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string,
+) (map[string]*gomatrixserverlib.HeaderedEvent, error) {
+ stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt)
+ rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
+ var evJson []byte
+ var eventID string
+ var membershipEventID string
+
+ knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs))
+
+ for rows.Next() {
+ if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil {
+ return nil, err
+ }
+ if len(evJson) == 0 {
+ result[eventID] = &gomatrixserverlib.HeaderedEvent{}
+ continue
+ }
+ // If we already know this event, don't try to marshal the json again
+ if ev, ok := knownEvents[membershipEventID]; ok {
+ result[eventID] = ev
+ continue
+ }
+ event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion)
+ if err != nil {
+ result[eventID] = &gomatrixserverlib.HeaderedEvent{}
+ // not fatal
+ continue
+ }
+ he := event.Headered(roomInfo.RoomVersion)
+ result[eventID] = he
+ knownEvents[membershipEventID] = he
+ }
+ return result, rows.Err()
+}