diff options
Diffstat (limited to 'roomserver/storage/sqlite3/state_block_table.go')
-rw-r--r-- | roomserver/storage/sqlite3/state_block_table.go | 242 |
1 files changed, 64 insertions, 178 deletions
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 2c544f2b..cfb2a49e 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "fmt" "sort" "strings" @@ -32,228 +33,113 @@ import ( const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + -- The state snapshot NID that identifies this snapshot. + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The hash of the state block, which is used to enforce uniqueness. The hash is + -- generated in Dendrite and passed through to the database, as a btree index over + -- this column is cheap and fits within the maximum index size. + state_block_hash BLOB UNIQUE, + -- The event NIDs contained within the state block, encoded as JSON. + event_nids TEXT NOT NULL DEFAULT '[]' ); ` -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +// Insert a new state block. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original block. +const insertStateDataSQL = ` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid ` -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + db *sql.DB + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt +} + +func createStateBlockTable(db *sql.DB) error { + _, err := db.Exec(stateDataSchema) + return err } -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } - _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil - } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + ctx context.Context, + txn *sql.Tx, + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids types.EventNIDs + for _, e := range entries { + nids = append(nids, e.EventNID) + } + js, err := json.Marshal(nids) if err != nil { - return 0, err - } - for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } + return 0, fmt.Errorf("json.Marshal: %w", err) } - return stateBlockNID, err + err = s.insertStateDataStmt.QueryRowContext( + ctx, nids.Hash(), js, + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + intfs := make([]interface{}, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + intfs[i] = int64(stateBlockNIDs[i]) } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) + rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result json.RawMessage + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + if err = json.Unmarshal(result, &r); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } - current.StateEntries = append(current.StateEntries, entry) - } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) + results[i] = r } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) - } - - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) - if err != nil { + if err = rows.Err(); err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs)) } - return results, nil + return results, err } type stateKeyTupleSorter []types.StateKeyTuple |