aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/sqlite3/state_block_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/sqlite3/state_block_table.go')
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go242
1 files changed, 64 insertions, 178 deletions
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index 2c544f2b..cfb2a49e 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -18,6 +18,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "encoding/json"
"fmt"
"sort"
"strings"
@@ -32,228 +33,113 @@ import (
const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
- state_block_nid INTEGER NOT NULL,
- event_type_nid INTEGER NOT NULL,
- event_state_key_nid INTEGER NOT NULL,
- event_nid INTEGER NOT NULL,
- UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
+ -- The state snapshot NID that identifies this snapshot.
+ state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The hash of the state block, which is used to enforce uniqueness. The hash is
+ -- generated in Dendrite and passed through to the database, as a btree index over
+ -- this column is cheap and fits within the maximum index size.
+ state_block_hash BLOB UNIQUE,
+ -- The event NIDs contained within the state block, encoded as JSON.
+ event_nids TEXT NOT NULL DEFAULT '[]'
);
`
-const insertStateDataSQL = "" +
- "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
- " VALUES ($1, $2, $3, $4)"
-
-const selectNextStateBlockNIDSQL = `
-SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
+// Insert a new state block. If we conflict on the hash column then
+// we must perform an update so that the RETURNING statement returns the
+// ID of the row that we conflicted with, so that we can then refer to
+// the original block.
+const insertStateDataSQL = `
+ INSERT INTO roomserver_state_block (state_block_hash, event_nids)
+ VALUES ($1, $2)
+ ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
+ RETURNING state_block_nid
`
-// Bulk state lookup by numeric state block ID.
-// Sort by the state_block_nid, event_type_nid, event_state_key_nid
-// This means that all the entries for a given state_block_nid will appear
-// together in the list and those entries will sorted by event_type_nid
-// and event_state_key_nid. This property makes it easier to merge two
-// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
-
-// Bulk state lookup by numeric state block ID.
-// Filters the rows in each block to the requested types and state keys.
-// We would like to restrict to particular type state key pairs but we are
-// restricted by the query language to pull the cross product of a list
-// of types and a list state_keys. So we have to filter the result in the
-// application to restrict it to the list of event types and state keys we
-// actually wanted.
-const bulkSelectFilteredStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
- " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+ "SELECT state_block_nid, event_nids" +
+ " FROM roomserver_state_block WHERE state_block_nid IN ($1)"
type stateBlockStatements struct {
- db *sql.DB
- insertStateDataStmt *sql.Stmt
- selectNextStateBlockNIDStmt *sql.Stmt
- bulkSelectStateBlockEntriesStmt *sql.Stmt
- bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
+ db *sql.DB
+ insertStateDataStmt *sql.Stmt
+ bulkSelectStateBlockEntriesStmt *sql.Stmt
+}
+
+func createStateBlockTable(db *sql.DB) error {
+ _, err := db.Exec(stateDataSchema)
+ return err
}
-func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
- _, err := db.Exec(stateDataSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
- {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
- {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
}
func (s *stateBlockStatements) BulkInsertStateData(
- ctx context.Context, txn *sql.Tx,
- entries []types.StateEntry,
-) (types.StateBlockNID, error) {
- if len(entries) == 0 {
- return 0, nil
- }
- var stateBlockNID types.StateBlockNID
- err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
+ ctx context.Context,
+ txn *sql.Tx,
+ entries types.StateEntries,
+) (id types.StateBlockNID, err error) {
+ entries = entries[:util.SortAndUnique(entries)]
+ var nids types.EventNIDs
+ for _, e := range entries {
+ nids = append(nids, e.EventNID)
+ }
+ js, err := json.Marshal(nids)
if err != nil {
- return 0, err
- }
- for _, entry := range entries {
- _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
- ctx,
- int64(stateBlockNID),
- int64(entry.EventTypeNID),
- int64(entry.EventStateKeyNID),
- int64(entry.EventNID),
- )
- if err != nil {
- return 0, err
- }
+ return 0, fmt.Errorf("json.Marshal: %w", err)
}
- return stateBlockNID, err
+ err = s.insertStateDataStmt.QueryRowContext(
+ ctx, nids.Hash(), js,
+ ).Scan(&id)
+ return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
- ctx context.Context, stateBlockNIDs []types.StateBlockNID,
-) ([]types.StateEntryList, error) {
- nids := make([]interface{}, len(stateBlockNIDs))
- for k, v := range stateBlockNIDs {
- nids[k] = v
+ ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
+) ([][]types.EventNID, error) {
+ intfs := make([]interface{}, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ intfs[i] = int64(stateBlockNIDs[i])
}
- selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
+ selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
- rows, err := selectStmt.QueryContext(ctx, nids...)
+ rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
- results := make([]types.StateEntryList, len(stateBlockNIDs))
- // current is a pointer to the StateEntryList to append the state entries to.
- var current *types.StateEntryList
+ results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err := rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
+ for ; rows.Next(); i++ {
+ var stateBlockNID types.StateBlockNID
+ var result json.RawMessage
+ if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
- if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we start appending to the next entry in the list.
- current = &results[i]
- current.StateBlockNID = types.StateBlockNID(stateBlockNID)
- i++
+ r := []types.EventNID{}
+ if err = json.Unmarshal(result, &r); err != nil {
+ return nil, fmt.Errorf("json.Unmarshal: %w", err)
}
- current.StateEntries = append(current.StateEntries, entry)
- }
- if i != len(nids) {
- return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
+ results[i] = r
}
- return results, nil
-}
-
-func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
- ctx context.Context,
- stateBlockNIDs []types.StateBlockNID,
- stateKeyTuples []types.StateKeyTuple,
-) ([]types.StateEntryList, error) {
- tuples := stateKeyTupleSorter(stateKeyTuples)
- // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
- sort.Sort(tuples)
-
- eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
- sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
- sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
- sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
-
- var params []interface{}
- for _, val := range stateBlockNIDs {
- params = append(params, int64(val))
- }
- for _, val := range eventTypeNIDArray {
- params = append(params, val)
- }
- for _, val := range eventStateKeyNIDArray {
- params = append(params, val)
- }
-
- rows, err := s.db.QueryContext(
- ctx,
- sqlStatement,
- params...,
- )
- if err != nil {
+ if err = rows.Err(); err != nil {
return nil, err
}
- defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
-
- var results []types.StateEntryList
- var current types.StateEntryList
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err := rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
- return nil, err
- }
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
-
- // We can use binary search here because we sorted the tuples earlier
- if !tuples.contains(entry.StateKeyTuple) {
- // The select will return the cross product of types and state keys.
- // So we need to check if type of the entry is in the list.
- continue
- }
-
- if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we append the current entry to the results and start adding to a new one.
- // The first time through the loop current will be empty.
- if current.StateEntries != nil {
- results = append(results, current)
- }
- current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
- }
- current.StateEntries = append(current.StateEntries, entry)
- }
- // Add the last entry to the list if it is not empty.
- if current.StateEntries != nil {
- results = append(results, current)
+ if i != len(stateBlockNIDs) {
+ return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs))
}
- return results, nil
+ return results, err
}
type stateKeyTupleSorter []types.StateKeyTuple