aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2021-04-26 13:25:57 +0100
committerGitHub <noreply@github.com>2021-04-26 13:25:57 +0100
commit5ce1fe80dea8b8cfca8712e8d584deb995bbddcc (patch)
tree1307a1edf73abf68cebd4601efec1e467dac964c
parentd6e9b7b307ff0d7541046ec33890d49239c7a6ca (diff)
State storage refactor (#1839)
* Hash-deduplicated state storage (and migrations) for PostgreSQL and SQLite * Refactor droomserver database setup for migrations * Fix conflict statements * Update migration names * Set a boundary for old to new block/snapshot IDs so we don't rewrite them more than once accidentally * Create sequence if not exists * Fix boundary queries * Fix boundary queries * Use Query * Break out queries a bit * More sequence tweaks * Query parameters are not playing the game * Injection escaping may not work for CREATE SEQUENCE after all * Fix snapshot sequence name * Use boundaried IDs in SQLite too * Use IFNULL for SQLite * Use COALESCE in PostgreSQL * Review comments @Kegsay
-rw-r--r--go.mod2
-rw-r--r--go.sum4
-rw-r--r--roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go1
-rw-r--r--roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go223
-rw-r--r--roomserver/storage/postgres/event_json_table.go12
-rw-r--r--roomserver/storage/postgres/event_state_keys_table.go12
-rw-r--r--roomserver/storage/postgres/event_types_table.go11
-rw-r--r--roomserver/storage/postgres/events_table.go60
-rw-r--r--roomserver/storage/postgres/invite_table.go11
-rw-r--r--roomserver/storage/postgres/membership_table.go16
-rw-r--r--roomserver/storage/postgres/previous_events_table.go11
-rw-r--r--roomserver/storage/postgres/published_table.go12
-rw-r--r--roomserver/storage/postgres/redactions_table.go11
-rw-r--r--roomserver/storage/postgres/room_aliases_table.go12
-rw-r--r--roomserver/storage/postgres/rooms_table.go12
-rw-r--r--roomserver/storage/postgres/state_block_table.go214
-rw-r--r--roomserver/storage/postgres/state_snapshot_table.go52
-rw-r--r--roomserver/storage/postgres/storage.go94
-rw-r--r--roomserver/storage/postgres/transactions_table.go11
-rw-r--r--roomserver/storage/shared/storage.go58
-rw-r--r--roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go1
-rw-r--r--roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go168
-rw-r--r--roomserver/storage/sqlite3/event_json_table.go12
-rw-r--r--roomserver/storage/sqlite3/event_state_keys_table.go12
-rw-r--r--roomserver/storage/sqlite3/event_types_table.go11
-rw-r--r--roomserver/storage/sqlite3/events_table.go72
-rw-r--r--roomserver/storage/sqlite3/invite_table.go11
-rw-r--r--roomserver/storage/sqlite3/membership_table.go12
-rw-r--r--roomserver/storage/sqlite3/previous_events_table.go11
-rw-r--r--roomserver/storage/sqlite3/published_table.go12
-rw-r--r--roomserver/storage/sqlite3/redactions_table.go11
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go12
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go12
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go242
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go42
-rw-r--r--roomserver/storage/sqlite3/storage.go90
-rw-r--r--roomserver/storage/sqlite3/transactions_table.go11
-rw-r--r--roomserver/storage/tables/interface.go9
-rw-r--r--roomserver/types/types.go40
39 files changed, 1076 insertions, 554 deletions
diff --git a/go.mod b/go.mod
index a3d80f1b..16c25ead 100644
--- a/go.mod
+++ b/go.mod
@@ -25,7 +25,7 @@ require (
github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
- github.com/mattn/go-sqlite3 v1.14.6
+ github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
github.com/opentracing/opentracing-go v1.2.0
diff --git a/go.sum b/go.sum
index 90b5527c..f245a270 100644
--- a/go.sum
+++ b/go.sum
@@ -684,8 +684,8 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
-github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
-github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
+github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb h1:ax2vG2unlxsjwS7PMRo4FECIfAdQLowd6ejWYwPQhBo=
+github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
index 733f0fa1..f3bd8632 100644
--- a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
+++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go
@@ -24,6 +24,7 @@ import (
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
+ goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go
new file mode 100644
index 00000000..84da9614
--- /dev/null
+++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go
@@ -0,0 +1,223 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "database/sql"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+)
+
+type stateSnapshotData struct {
+ StateSnapshotNID types.StateSnapshotNID
+ RoomNID types.RoomNID
+}
+
+type stateBlockData struct {
+ stateSnapshotData
+ StateBlockNID types.StateBlockNID
+ EventNIDs types.EventNIDs
+}
+
+func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
+ m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
+}
+
+// nolint:gocyclo
+func UpStateBlocksRefactor(tx *sql.Tx) error {
+ logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
+ defer logrus.Warn("State storage upgrade complete")
+
+ var snapshotcount int
+ var maxsnapshotid int
+ var maxblockid int
+ if err := tx.QueryRow(`SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ }
+ if err := tx.QueryRow(`SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ }
+ if err := tx.QueryRow(`SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ }
+ maxsnapshotid++
+ maxblockid++
+
+ if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ // We create new sequences starting with the maximum state snapshot and block NIDs.
+ // This means that all newly created snapshots and blocks by the migration will have
+ // NIDs higher than these values, so that when we come to update the references to
+ // these NIDs using UPDATE statements, we can guarantee we are only ever updating old
+ // values and not accidentally overwriting new ones.
+ if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ _, err := tx.Exec(`
+ CREATE TABLE IF NOT EXISTS roomserver_state_block (
+ state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'),
+ state_block_hash BYTEA UNIQUE,
+ event_nids bigint[] NOT NULL
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("tx.Exec (create blocks table): %w", err)
+ }
+ _, err = tx.Exec(`
+ CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
+ state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'),
+ state_snapshot_hash BYTEA UNIQUE,
+ room_nid bigint NOT NULL,
+ state_block_nids bigint[] NOT NULL
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("tx.Exec (create snapshots table): %w", err)
+ }
+ logrus.Warn("New tables created...")
+
+ batchsize := 100
+ for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize {
+ var snapshotrows *sql.Rows
+ snapshotrows, err = tx.Query(`
+ SELECT
+ state_snapshot_nid,
+ room_nid,
+ state_block_nid,
+ ARRAY_AGG(event_nid) AS event_nids
+ FROM (
+ SELECT
+ _roomserver_state_snapshots.state_snapshot_nid,
+ _roomserver_state_snapshots.room_nid,
+ _roomserver_state_block.state_block_nid,
+ _roomserver_state_block.event_nid
+ FROM
+ _roomserver_state_snapshots
+ JOIN _roomserver_state_block ON _roomserver_state_block.state_block_nid = ANY (_roomserver_state_snapshots.state_block_nids)
+ WHERE
+ _roomserver_state_snapshots.state_snapshot_nid = ANY ( SELECT DISTINCT
+ _roomserver_state_snapshots.state_snapshot_nid
+ FROM
+ _roomserver_state_snapshots
+ LIMIT $1 OFFSET $2)) AS _roomserver_state_block
+ GROUP BY
+ state_snapshot_nid,
+ room_nid,
+ state_block_nid;
+ `, batchsize, batchoffset)
+ if err != nil {
+ return fmt.Errorf("tx.Query: %w", err)
+ }
+
+ logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount)
+ var snapshots []stateBlockData
+
+ for snapshotrows.Next() {
+ var snapshot stateBlockData
+ var eventsarray pq.Int64Array
+ if err = snapshotrows.Scan(&snapshot.StateSnapshotNID, &snapshot.RoomNID, &snapshot.StateBlockNID, &eventsarray); err != nil {
+ return fmt.Errorf("rows.Scan: %w", err)
+ }
+ for _, e := range eventsarray {
+ snapshot.EventNIDs = append(snapshot.EventNIDs, types.EventNID(e))
+ }
+ snapshot.EventNIDs = snapshot.EventNIDs[:util.SortAndUnique(snapshot.EventNIDs)]
+ snapshots = append(snapshots, snapshot)
+ }
+
+ if err = snapshotrows.Close(); err != nil {
+ return fmt.Errorf("snapshots.Close: %w", err)
+ }
+
+ newsnapshots := map[stateSnapshotData]types.StateBlockNIDs{}
+
+ for _, snapshot := range snapshots {
+ var eventsarray pq.Int64Array
+ for _, e := range snapshot.EventNIDs {
+ eventsarray = append(eventsarray, int64(e))
+ }
+
+ var blocknid types.StateBlockNID
+ err = tx.QueryRow(`
+ INSERT INTO roomserver_state_block (state_block_hash, event_nids)
+ VALUES ($1, $2)
+ ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
+ RETURNING state_block_nid
+ `, snapshot.EventNIDs.Hash(), eventsarray).Scan(&blocknid)
+ if err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (insert new block with %d events): %w", len(eventsarray), err)
+ }
+ index := stateSnapshotData{snapshot.StateSnapshotNID, snapshot.RoomNID}
+ newsnapshots[index] = append(newsnapshots[index], blocknid)
+ }
+
+ for snapshotdata, newblocks := range newsnapshots {
+ var newblocksarray pq.Int64Array
+ for _, b := range newblocks {
+ newblocksarray = append(newblocksarray, int64(b))
+ }
+
+ var newNID types.StateSnapshotNID
+ err = tx.QueryRow(`
+ INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
+ RETURNING state_snapshot_nid
+ `, newblocks.Hash(), snapshotdata.RoomNID, newblocksarray).Scan(&newNID)
+ if err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
+ }
+
+ if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.Exec (update events): %w", err)
+ }
+
+ if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.Exec (update rooms): %w", err)
+ }
+ }
+ }
+
+ if _, err = tx.Exec(`
+ DROP TABLE _roomserver_state_snapshots;
+ DROP SEQUENCE roomserver_state_snapshot_nid_seq;
+ `); err != nil {
+ return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
+ }
+ if _, err = tx.Exec(`
+ DROP TABLE _roomserver_state_block;
+ DROP SEQUENCE roomserver_state_block_nid_seq;
+ `); err != nil {
+ return fmt.Errorf("tx.Exec (delete old block table): %w", err)
+ }
+
+ return nil
+}
+
+func DownStateBlocksRefactor(tx *sql.Tx) error {
+ panic("Downgrading state storage is not supported")
+}
diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go
index 8f11d1d8..e0976b12 100644
--- a/roomserver/storage/postgres/event_json_table.go
+++ b/roomserver/storage/postgres/event_json_table.go
@@ -59,12 +59,14 @@ type eventJSONStatements struct {
bulkSelectEventJSONStmt *sql.Stmt
}
-func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
- s := &eventJSONStatements{}
+func createEventJSONTable(db *sql.DB) error {
_, err := db.Exec(eventJSONSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
+ s := &eventJSONStatements{}
+
return s, shared.StatementList{
{&s.insertEventJSONStmt, insertEventJSONSQL},
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go
index 500ff20e..61682356 100644
--- a/roomserver/storage/postgres/event_state_keys_table.go
+++ b/roomserver/storage/postgres/event_state_keys_table.go
@@ -77,12 +77,14 @@ type eventStateKeyStatements struct {
bulkSelectEventStateKeyStmt *sql.Stmt
}
-func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
- s := &eventStateKeyStatements{}
+func createEventStateKeysTable(db *sql.DB) error {
_, err := db.Exec(eventStateKeysSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
+ s := &eventStateKeyStatements{}
+
return s, shared.StatementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go
index 02d6ad07..f4257850 100644
--- a/roomserver/storage/postgres/event_types_table.go
+++ b/roomserver/storage/postgres/event_types_table.go
@@ -100,12 +100,13 @@ type eventTypeStatements struct {
bulkSelectEventTypeNIDStmt *sql.Stmt
}
-func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
- s := &eventTypeStatements{}
+func createEventTypesTable(db *sql.DB) error {
_, err := db.Exec(eventTypesSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
+ s := &eventTypeStatements{}
return s, shared.StatementList{
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go
index 0cf0bd22..88c82083 100644
--- a/roomserver/storage/postgres/events_table.go
+++ b/roomserver/storage/postgres/events_table.go
@@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"fmt"
+ "sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@@ -88,6 +89,16 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
+// Bulk look up of events by event NID, optionally filtering by the event type
+// or event state key NIDs if provided. (The CARDINALITY check will return true
+// if the provided arrays are empty, ergo no filtering).
+const bulkSelectStateEventByNIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
+ " WHERE event_nid = ANY($1)" +
+ " AND (CARDINALITY($2::bigint[]) = 0 OR event_type_nid = ANY($2))" +
+ " AND (CARDINALITY($3::bigint[]) = 0 OR event_state_key_nid = ANY($3))" +
+ " ORDER BY event_type_nid, event_state_key_nid ASC"
+
const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id = ANY($1)"
@@ -127,6 +138,7 @@ type eventStatements struct {
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
+ bulkSelectStateEventByNIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt
selectEventSentToOutputStmt *sql.Stmt
@@ -140,17 +152,19 @@ type eventStatements struct {
selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
-func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
- s := &eventStatements{}
+func createEventsTable(db *sql.DB) error {
_, err := db.Exec(eventsSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareEventsTable(db *sql.DB) (tables.Events, error) {
+ s := &eventStatements{}
return s, shared.StatementList{
{&s.insertEventStmt, insertEventSQL},
{&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
+ {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL},
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
@@ -238,6 +252,42 @@ func (s *eventStatements) BulkSelectStateEventByID(
return results, nil
}
+// bulkSelectStateEventByNID lookups a list of state events by event NID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError
+func (s *eventStatements) BulkSelectStateEventByNID(
+ ctx context.Context, eventNIDs []types.EventNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ tuples := stateKeyTupleSorter(stateKeyTuples)
+ sort.Sort(tuples)
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
+ // We know that we will only get as many results as event IDs
+ // because of the unique constraint on event IDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than IDs so we adjust the length of the slice before returning it.
+ results := make([]types.StateEntry, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ ); err != nil {
+ return nil, err
+ }
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return results[:i], nil
+}
+
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go
index bb719516..0a2183e2 100644
--- a/roomserver/storage/postgres/invite_table.go
+++ b/roomserver/storage/postgres/invite_table.go
@@ -82,12 +82,13 @@ type inviteStatements struct {
updateInviteRetiredStmt *sql.Stmt
}
-func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) {
- s := &inviteStatements{}
+func createInvitesTable(db *sql.DB) error {
_, err := db.Exec(inviteSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
+ s := &inviteStatements{}
return s, shared.StatementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index e392a4fb..3466da6d 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -139,12 +139,13 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
}
-func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
- s := &membershipStatements{}
+func createMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
+ s := &membershipStatements{}
return s, shared.StatementList{
{&s.insertMembershipStmt, insertMembershipSQL},
@@ -162,11 +163,6 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
}.Prepare(db)
}
-func (s *membershipStatements) execSchema(db *sql.DB) error {
- _, err := db.Exec(membershipSchema)
- return err
-}
-
func (s *membershipStatements) InsertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go
index 1a4ba673..4a93c3d6 100644
--- a/roomserver/storage/postgres/previous_events_table.go
+++ b/roomserver/storage/postgres/previous_events_table.go
@@ -65,12 +65,13 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt
}
-func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
- s := &previousEventStatements{}
+func createPrevEventsTable(db *sql.DB) error {
_, err := db.Exec(previousEventSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
+ s := &previousEventStatements{}
return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go
index 440ae784..c180576e 100644
--- a/roomserver/storage/postgres/published_table.go
+++ b/roomserver/storage/postgres/published_table.go
@@ -50,12 +50,14 @@ type publishedStatements struct {
selectPublishedStmt *sql.Stmt
}
-func NewPostgresPublishedTable(db *sql.DB) (tables.Published, error) {
- s := &publishedStatements{}
+func createPublishedTable(db *sql.DB) error {
_, err := db.Exec(publishedSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func preparePublishedTable(db *sql.DB) (tables.Published, error) {
+ s := &publishedStatements{}
+
return s, shared.StatementList{
{&s.upsertPublishedStmt, upsertPublishedSQL},
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
diff --git a/roomserver/storage/postgres/redactions_table.go b/roomserver/storage/postgres/redactions_table.go
index 42aba598..3741d5f6 100644
--- a/roomserver/storage/postgres/redactions_table.go
+++ b/roomserver/storage/postgres/redactions_table.go
@@ -60,12 +60,13 @@ type redactionStatements struct {
markRedactionValidatedStmt *sql.Stmt
}
-func NewPostgresRedactionsTable(db *sql.DB) (tables.Redactions, error) {
- s := &redactionStatements{}
+func createRedactionsTable(db *sql.DB) error {
_, err := db.Exec(redactionsSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) {
+ s := &redactionStatements{}
return s, shared.StatementList{
{&s.insertRedactionStmt, insertRedactionSQL},
diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go
index b603a673..c808813e 100644
--- a/roomserver/storage/postgres/room_aliases_table.go
+++ b/roomserver/storage/postgres/room_aliases_table.go
@@ -62,12 +62,14 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
- s := &roomAliasesStatements{}
+func createRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
+ s := &roomAliasesStatements{}
+
return s, shared.StatementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go
index 637680bd..f2b39fe5 100644
--- a/roomserver/storage/postgres/rooms_table.go
+++ b/roomserver/storage/postgres/rooms_table.go
@@ -96,12 +96,14 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt
}
-func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
- s := &roomStatements{}
+func createRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
+ s := &roomStatements{}
+
return s, shared.StatementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go
index d618686f..4523d18b 100644
--- a/roomserver/storage/postgres/state_block_table.go
+++ b/roomserver/storage/postgres/state_block_table.go
@@ -41,141 +41,88 @@ const stateDataSchema = `
-- which in turn makes it easier to merge state data blocks.
CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq;
CREATE TABLE IF NOT EXISTS roomserver_state_block (
- -- Local numeric ID for this state data.
- state_block_nid bigint NOT NULL,
- event_type_nid bigint NOT NULL,
- event_state_key_nid bigint NOT NULL,
- event_nid bigint NOT NULL,
- UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
+ -- The state snapshot NID that identifies this snapshot.
+ state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_seq'),
+ -- The hash of the state block, which is used to enforce uniqueness. The hash is
+ -- generated in Dendrite and passed through to the database, as a btree index over
+ -- this column is cheap and fits within the maximum index size.
+ state_block_hash BYTEA UNIQUE,
+ -- The event NIDs contained within the state block.
+ event_nids bigint[] NOT NULL
);
`
+// Insert a new state block. If we conflict on the hash column then
+// we must perform an update so that the RETURNING statement returns the
+// ID of the row that we conflicted with, so that we can then refer to
+// the original block.
const insertStateDataSQL = "" +
- "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
- " VALUES ($1, $2, $3, $4)"
+ "INSERT INTO roomserver_state_block (state_block_hash, event_nids)" +
+ " VALUES ($1, $2)" +
+ " ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2" +
+ " RETURNING state_block_nid"
-const selectNextStateBlockNIDSQL = "" +
- "SELECT nextval('roomserver_state_block_nid_seq')"
-
-// Bulk state lookup by numeric state block ID.
-// Sort by the state_block_nid, event_type_nid, event_state_key_nid
-// This means that all the entries for a given state_block_nid will appear
-// together in the list and those entries will sorted by event_type_nid
-// and event_state_key_nid. This property makes it easier to merge two
-// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
-
-// Bulk state lookup by numeric state block ID.
-// Filters the rows in each block to the requested types and state keys.
-// We would like to restrict to particular type state key pairs but we are
-// restricted by the query language to pull the cross product of a list
-// of types and a list state_keys. So we have to filter the result in the
-// application to restrict it to the list of event types and state keys we
-// actually wanted.
-const bulkSelectFilteredStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
- " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+ "SELECT state_block_nid, event_nids" +
+ " FROM roomserver_state_block WHERE state_block_nid = ANY($1)"
type stateBlockStatements struct {
- insertStateDataStmt *sql.Stmt
- selectNextStateBlockNIDStmt *sql.Stmt
- bulkSelectStateBlockEntriesStmt *sql.Stmt
- bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
+ insertStateDataStmt *sql.Stmt
+ bulkSelectStateBlockEntriesStmt *sql.Stmt
}
-func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
- s := &stateBlockStatements{}
+func createStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+ s := &stateBlockStatements{}
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
- {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
- {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
}
func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
txn *sql.Tx,
- entries []types.StateEntry,
-) (types.StateBlockNID, error) {
- stateBlockNID, err := s.selectNextStateBlockNID(ctx)
- if err != nil {
- return 0, err
- }
- for _, entry := range entries {
- _, err := s.insertStateDataStmt.ExecContext(
- ctx,
- int64(stateBlockNID),
- int64(entry.EventTypeNID),
- int64(entry.EventStateKeyNID),
- int64(entry.EventNID),
- )
- if err != nil {
- return 0, err
- }
+ entries types.StateEntries,
+) (id types.StateBlockNID, err error) {
+ entries = entries[:util.SortAndUnique(entries)]
+ var nids types.EventNIDs
+ for _, e := range entries {
+ nids = append(nids, e.EventNID)
}
- return stateBlockNID, nil
-}
-
-func (s *stateBlockStatements) selectNextStateBlockNID(
- ctx context.Context,
-) (types.StateBlockNID, error) {
- var stateBlockNID int64
- err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
- return types.StateBlockNID(stateBlockNID), err
+ err = s.insertStateDataStmt.QueryRowContext(
+ ctx, nids.Hash(), eventNIDsAsArray(nids),
+ ).Scan(&id)
+ return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
- ctx context.Context, stateBlockNIDs []types.StateBlockNID,
-) ([]types.StateEntryList, error) {
- nids := make([]int64, len(stateBlockNIDs))
- for i := range stateBlockNIDs {
- nids[i] = int64(stateBlockNIDs[i])
- }
- rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
+ ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
+) ([][]types.EventNID, error) {
+ rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
- results := make([]types.StateEntryList, len(stateBlockNIDs))
- // current is a pointer to the StateEntryList to append the state entries to.
- var current *types.StateEntryList
+ results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err = rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
+ for ; rows.Next(); i++ {
+ var stateBlockNID types.StateBlockNID
+ var result pq.Int64Array
+ if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
- if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we start appending to the next entry in the list.
- current = &results[i]
- current.StateBlockNID = types.StateBlockNID(stateBlockNID)
- i++
+ r := []types.EventNID{}
+ for _, e := range result {
+ r = append(r, types.EventNID(e))
}
- current.StateEntries = append(current.StateEntries, entry)
+ results[i] = r
}
if err = rows.Err(); err != nil {
return nil, err
@@ -186,71 +133,6 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
return results, err
}
-func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
- ctx context.Context,
- stateBlockNIDs []types.StateBlockNID,
- stateKeyTuples []types.StateKeyTuple,
-) ([]types.StateEntryList, error) {
- tuples := stateKeyTupleSorter(stateKeyTuples)
- // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
- sort.Sort(tuples)
-
- eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
- rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
- ctx,
- stateBlockNIDsAsArray(stateBlockNIDs),
- eventTypeNIDArray,
- eventStateKeyNIDArray,
- )
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
-
- var results []types.StateEntryList
- var current types.StateEntryList
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err := rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
- return nil, err
- }
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
-
- // We can use binary search here because we sorted the tuples earlier
- if !tuples.contains(entry.StateKeyTuple) {
- // The select will return the cross product of types and state keys.
- // So we need to check if type of the entry is in the list.
- continue
- }
-
- if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we append the current entry to the results and start adding to a new one.
- // The first time through the loop current will be empty.
- if current.StateEntries != nil {
- results = append(results, current)
- }
- current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
- }
- current.StateEntries = append(current.StateEntries, entry)
- }
- // Add the last entry to the list if it is not empty.
- if current.StateEntries != nil {
- results = append(results, current)
- }
- return results, rows.Err()
-}
-
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index 63175955..15e14e2e 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
)
const stateSnapshotSchema = `
@@ -40,19 +41,29 @@ const stateSnapshotSchema = `
-- the full state under single state_block_nid.
CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq;
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
- -- Local numeric ID for the state.
- state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
- -- Local numeric ID of the room this state is for.
- -- Unused in normal operation, but useful for background work or ad-hoc debugging.
- room_nid bigint NOT NULL,
- -- List of state_block_nids, stored sorted by state_block_nid.
- state_block_nids bigint[] NOT NULL
+ -- The state snapshot NID that identifies this snapshot.
+ state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
+ -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is
+ -- generated in Dendrite and passed through to the database, as a btree index over
+ -- this column is cheap and fits within the maximum index size.
+ state_snapshot_hash BYTEA UNIQUE,
+ -- The room NID that the snapshot belongs to.
+ room_nid bigint NOT NULL,
+ -- The state blocks contained within this snapshot.
+ state_block_nids bigint[] NOT NULL
);
`
+// Insert a new state snapshot. If we conflict on the hash column then
+// we must perform an update so that the RETURNING statement returns the
+// ID of the row that we conflicted with, so that we can then refer to
+// the original snapshot.
const insertStateSQL = "" +
- "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" +
- " VALUES ($1, $2)" +
+ "INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2" +
+ // Performing an update, above, ensures that the RETURNING statement
+ // below will always return a valid state snapshot ID
" RETURNING state_snapshot_nid"
// Bulk state data NID lookup.
@@ -67,12 +78,13 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
- s := &stateSnapshotStatements{}
+func createStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+ s := &stateSnapshotStatements{}
return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL},
@@ -81,13 +93,15 @@ func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
}
func (s *stateSnapshotStatements) InsertState(
- ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
- nids := make([]int64, len(stateBlockNIDs))
- for i := range stateBlockNIDs {
- nids[i] = int64(stateBlockNIDs[i])
+ nids = nids[:util.SortAndUnique(nids)]
+ var id int64
+ err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
+ if err != nil {
+ return 0, err
}
- err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
+ stateNID = types.StateSnapshotNID(id)
return
}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index bb3f841d..863a1593 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -17,6 +17,7 @@ package postgres
import (
"database/sql"
+ "fmt"
// Import the postgres database driver.
_ "github.com/lib/pq"
@@ -39,20 +40,25 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
var db *sql.DB
var err error
if db, err = sqlutil.Open(dbProperties); err != nil {
- return nil, err
+ return nil, fmt.Errorf("sqlutil.Open: %w", err)
}
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- ms := membershipStatements{}
- if err := ms.execSchema(db); err != nil {
+ // Create the tables.
+ if err := d.create(db); err != nil {
return nil, err
}
+
+ // Then execute the migrations. By this point the tables are created with the latest
+ // schemas.
m := sqlutil.NewMigrations()
deltas.LoadAddForgottenColumn(m)
+ deltas.LoadStateBlocksRefactor(m)
if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
+
+ // Then prepare the statements. Now that the migrations have run, any columns referred
+ // to in the database code should now exist.
if err := d.prepare(db, cache); err != nil {
return nil, err
}
@@ -60,61 +66,107 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
return &d, nil
}
-// nolint: gocyclo
-func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) {
- eventStateKeys, err := NewPostgresEventStateKeysTable(db)
+func (d *Database) create(db *sql.DB) error {
+ if err := createEventStateKeysTable(db); err != nil {
+ return err
+ }
+ if err := createEventTypesTable(db); err != nil {
+ return err
+ }
+ if err := createEventJSONTable(db); err != nil {
+ return err
+ }
+ if err := createEventsTable(db); err != nil {
+ return err
+ }
+ if err := createRoomsTable(db); err != nil {
+ return err
+ }
+ if err := createTransactionsTable(db); err != nil {
+ return err
+ }
+ if err := createStateBlockTable(db); err != nil {
+ return err
+ }
+ if err := createStateSnapshotTable(db); err != nil {
+ return err
+ }
+ if err := createPrevEventsTable(db); err != nil {
+ return err
+ }
+ if err := createRoomAliasesTable(db); err != nil {
+ return err
+ }
+ if err := createInvitesTable(db); err != nil {
+ return err
+ }
+ if err := createMembershipTable(db); err != nil {
+ return err
+ }
+ if err := createPublishedTable(db); err != nil {
+ return err
+ }
+ if err := createRedactionsTable(db); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
+ eventStateKeys, err := prepareEventStateKeysTable(db)
if err != nil {
return err
}
- eventTypes, err := NewPostgresEventTypesTable(db)
+ eventTypes, err := prepareEventTypesTable(db)
if err != nil {
return err
}
- eventJSON, err := NewPostgresEventJSONTable(db)
+ eventJSON, err := prepareEventJSONTable(db)
if err != nil {
return err
}
- events, err := NewPostgresEventsTable(db)
+ events, err := prepareEventsTable(db)
if err != nil {
return err
}
- rooms, err := NewPostgresRoomsTable(db)
+ rooms, err := prepareRoomsTable(db)
if err != nil {
return err
}
- transactions, err := NewPostgresTransactionsTable(db)
+ transactions, err := prepareTransactionsTable(db)
if err != nil {
return err
}
- stateBlock, err := NewPostgresStateBlockTable(db)
+ stateBlock, err := prepareStateBlockTable(db)
if err != nil {
return err
}
- stateSnapshot, err := NewPostgresStateSnapshotTable(db)
+ stateSnapshot, err := prepareStateSnapshotTable(db)
if err != nil {
return err
}
- roomAliases, err := NewPostgresRoomAliasesTable(db)
+ prevEvents, err := preparePrevEventsTable(db)
if err != nil {
return err
}
- prevEvents, err := NewPostgresPreviousEventsTable(db)
+ roomAliases, err := prepareRoomAliasesTable(db)
if err != nil {
return err
}
- invites, err := NewPostgresInvitesTable(db)
+ invites, err := prepareInvitesTable(db)
if err != nil {
return err
}
- membership, err := NewPostgresMembershipTable(db)
+ membership, err := prepareMembershipTable(db)
if err != nil {
return err
}
- published, err := NewPostgresPublishedTable(db)
+ published, err := preparePublishedTable(db)
if err != nil {
return err
}
- redactions, err := NewPostgresRedactionsTable(db)
+ redactions, err := prepareRedactionsTable(db)
if err != nil {
return err
}
diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go
index 5e59ae16..cada0d8a 100644
--- a/roomserver/storage/postgres/transactions_table.go
+++ b/roomserver/storage/postgres/transactions_table.go
@@ -54,12 +54,13 @@ type transactionStatements struct {
selectTransactionEventIDStmt *sql.Stmt
}
-func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) {
- s := &transactionStatements{}
+func createTransactionsTable(db *sql.DB) error {
_, err := db.Exec(transactionsSchema)
- if err != nil {
- return nil, err
- }
+ return err
+}
+
+func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) {
+ s := &transactionStatements{}
return s, shared.StatementList{
{&s.insertTransactionStmt, insertTransactionSQL},
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 24b48772..096d5d7a 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -118,9 +118,24 @@ func (d *Database) StateEntriesForTuples(
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
- return d.StateBlockTable.BulkSelectFilteredStateBlockEntries(
- ctx, stateBlockNIDs, stateKeyTuples,
+ entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
+ ctx, stateBlockNIDs,
)
+ if err != nil {
+ return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
+ }
+ lists := []types.StateEntryList{}
+ for i, entry := range entries {
+ entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
+ if err != nil {
+ return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
+ }
+ lists = append(lists, types.StateEntryList{
+ StateBlockNID: stateBlockNIDs[i],
+ StateEntries: entries,
+ })
+ }
+ return lists, nil
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
@@ -141,8 +156,28 @@ func (d *Database) AddState(
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
+ if len(stateBlockNIDs) > 0 {
+ // Check to see if the event already appears in any of the existing state
+ // blocks. If it does then we should not add it again, as this will just
+ // result in excess state blocks and snapshots.
+ // TODO: Investigate why this is happening - probably input_events.go!
+ blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+ if berr != nil {
+ return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
+ }
+ for i := len(state) - 1; i >= 0; i-- {
+ for _, events := range blocks {
+ for _, event := range events {
+ if state[i].EventNID == event {
+ state = append(state[:i], state[i+1:]...)
+ }
+ }
+ }
+ }
+ }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if len(state) > 0 {
+ // If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
if err != nil {
@@ -237,7 +272,24 @@ func (d *Database) StateBlockNIDs(
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
- return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+ entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
+ ctx, stateBlockNIDs,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
+ }
+ lists := make([]types.StateEntryList, 0, len(entries))
+ for i, entry := range entries {
+ eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
+ if err != nil {
+ return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
+ }
+ lists = append(lists, types.StateEntryList{
+ StateBlockNID: stateBlockNIDs[i],
+ StateEntries: eventNIDs,
+ })
+ }
+ return lists, nil
}
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
index 33fe9e2a..d08ab02d 100644
--- a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
+++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go
@@ -24,6 +24,7 @@ import (
func LoadFromGoose() {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
+ goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go
new file mode 100644
index 00000000..56158545
--- /dev/null
+++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go
@@ -0,0 +1,168 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+)
+
+func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
+ m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
+}
+
+func UpStateBlocksRefactor(tx *sql.Tx) error {
+ logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
+ defer logrus.Warn("State storage upgrade complete")
+
+ var maxsnapshotid int
+ var maxblockid int
+ if err := tx.QueryRow(`SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ }
+ if err := tx.QueryRow(`SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err)
+ }
+ maxsnapshotid++
+ maxblockid++
+
+ if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ _, err := tx.Exec(`
+ CREATE TABLE IF NOT EXISTS roomserver_state_block (
+ state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ state_block_hash BLOB UNIQUE,
+ event_nids TEXT NOT NULL DEFAULT '[]'
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ _, err = tx.Exec(`
+ CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
+ state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ state_snapshot_hash BLOB UNIQUE,
+ room_nid INTEGER NOT NULL,
+ state_block_nids TEXT NOT NULL DEFAULT '[]'
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("tx.Exec: %w", err)
+ }
+ snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`)
+ if err != nil {
+ return fmt.Errorf("tx.Query: %w", err)
+ }
+ defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed")
+ for snapshotrows.Next() {
+ var snapshot types.StateSnapshotNID
+ var room types.RoomNID
+ var jsonblocks string
+ var blocks []types.StateBlockNID
+ if err = snapshotrows.Scan(&snapshot, &room, &jsonblocks); err != nil {
+ return fmt.Errorf("rows.Scan: %w", err)
+ }
+ if err = json.Unmarshal([]byte(jsonblocks), &blocks); err != nil {
+ return fmt.Errorf("json.Unmarshal: %w", err)
+ }
+
+ var newblocks types.StateBlockNIDs
+ for _, block := range blocks {
+ if err = func() error {
+ blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block)
+ if berr != nil {
+ return fmt.Errorf("tx.Query (event nids from old block): %w", berr)
+ }
+ defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed")
+ events := types.EventNIDs{}
+ for blockrows.Next() {
+ var event types.EventNID
+ if err = blockrows.Scan(&event); err != nil {
+ return fmt.Errorf("rows.Scan: %w", err)
+ }
+ events = append(events, event)
+ }
+ events = events[:util.SortAndUnique(events)]
+ eventjson, eerr := json.Marshal(events)
+ if eerr != nil {
+ return fmt.Errorf("json.Marshal: %w", eerr)
+ }
+
+ var blocknid types.StateBlockNID
+ err = tx.QueryRow(`
+ INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3
+ RETURNING state_block_nid
+ `, maxblockid, events.Hash(), eventjson).Scan(&blocknid)
+ if err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err)
+ }
+ maxblockid++
+ newblocks = append(newblocks, blocknid)
+ return nil
+ }(); err != nil {
+ return err
+ }
+
+ newblocksjson, jerr := json.Marshal(newblocks)
+ if jerr != nil {
+ return fmt.Errorf("json.Marshal (new blocks): %w", jerr)
+ }
+ var newsnapshot types.StateSnapshotNID
+ err = tx.QueryRow(`
+ INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3
+ RETURNING state_snapshot_nid
+ `, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot)
+ if err != nil {
+ return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
+ }
+ maxsnapshotid++
+ if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$2`, newsnapshot, snapshot, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.Exec (update events): %w", err)
+ }
+ if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$2`, newsnapshot, snapshot, maxsnapshotid); err != nil {
+ return fmt.Errorf("tx.Exec (update rooms): %w", err)
+ }
+ }
+ }
+
+ if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil {
+ return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
+ }
+ if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil {
+ return fmt.Errorf("tx.Exec (delete old block table): %w", err)
+ }
+
+ return nil
+}
+
+func DownStateBlocksRefactor(tx *sql.Tx) error {
+ panic("Downgrading state storage is not supported")
+}
diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go
index 3cd44b1d..29d54b83 100644
--- a/roomserver/storage/sqlite3/event_json_table.go
+++ b/roomserver/storage/sqlite3/event_json_table.go
@@ -53,14 +53,16 @@ type eventJSONStatements struct {
bulkSelectEventJSONStmt *sql.Stmt
}
-func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
+func createEventJSONTable(db *sql.DB) error {
+ _, err := db.Exec(eventJSONSchema)
+ return err
+}
+
+func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
s := &eventJSONStatements{
db: db,
}
- _, err := db.Exec(eventJSONSchema)
- if err != nil {
- return nil, err
- }
+
return s, shared.StatementList{
{&s.insertEventJSONStmt, insertEventJSONSQL},
{&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go
index 345df8c6..d430e553 100644
--- a/roomserver/storage/sqlite3/event_state_keys_table.go
+++ b/roomserver/storage/sqlite3/event_state_keys_table.go
@@ -70,14 +70,16 @@ type eventStateKeyStatements struct {
bulkSelectEventStateKeyStmt *sql.Stmt
}
-func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
+func createEventStateKeysTable(db *sql.DB) error {
+ _, err := db.Exec(eventStateKeysSchema)
+ return err
+}
+
+func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
db: db,
}
- _, err := db.Exec(eventStateKeysSchema)
- if err != nil {
- return nil, err
- }
+
return s, shared.StatementList{
{&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
{&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
index 26e2bf84..694f4e21 100644
--- a/roomserver/storage/sqlite3/event_types_table.go
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -85,14 +85,15 @@ type eventTypeStatements struct {
bulkSelectEventTypeNIDStmt *sql.Stmt
}
-func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
+func createEventTypesTable(db *sql.DB) error {
+ _, err := db.Exec(eventTypesSchema)
+ return err
+}
+
+func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
s := &eventTypeStatements{
db: db,
}
- _, err := db.Exec(eventTypesSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index 53269657..e964770d 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
+ "sort"
"strings"
"github.com/matrix-org/dendrite/internal"
@@ -63,6 +64,11 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id IN ($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
+const bulkSelectStateEventByNIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
+ " WHERE event_nid IN ($1)"
+ // Rest of query is built by BulkSelectStateEventByNID
+
const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id IN ($1)"
@@ -115,14 +121,15 @@ type eventStatements struct {
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
}
-func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
+func createEventsTable(db *sql.DB) error {
+ _, err := db.Exec(eventsSchema)
+ return err
+}
+
+func prepareEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{
db: db,
}
- _, err := db.Exec(eventsSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertEventStmt, insertEventSQL},
@@ -232,6 +239,61 @@ func (s *eventStatements) BulkSelectStateEventByID(
return results, err
}
+// bulkSelectStateEventByID lookups a list of state events by event ID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError
+func (s *eventStatements) BulkSelectStateEventByNID(
+ ctx context.Context, eventNIDs []types.EventNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntry, error) {
+ tuples := stateKeyTupleSorter(stateKeyTuples)
+ sort.Sort(tuples)
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
+ selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
+ for _, v := range eventNIDs {
+ params = append(params, v)
+ }
+ if len(eventTypeNIDArray) > 0 {
+ selectOrig += " AND event_type_nid IN " + sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(params))
+ for _, v := range eventTypeNIDArray {
+ params = append(params, v)
+ }
+ }
+ if len(eventStateKeyNIDArray) > 0 {
+ selectOrig += " AND event_state_key_nid IN " + sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(params))
+ for _, v := range eventStateKeyNIDArray {
+ params = append(params, v)
+ }
+ }
+ selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC"
+ selectStmt, err := s.db.Prepare(selectOrig)
+ if err != nil {
+ return nil, fmt.Errorf("s.db.Prepare: %w", err)
+ }
+ rows, err := selectStmt.QueryContext(ctx, params...)
+ if err != nil {
+ return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
+ // We know that we will only get as many results as event IDs
+ // because of the unique constraint on event IDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than IDs so we adjust the length of the slice before returning it.
+ results := make([]types.StateEntry, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ ); err != nil {
+ return nil, err
+ }
+ }
+ return results[:i], err
+}
+
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go
index 327be6a0..e1aa1ebd 100644
--- a/roomserver/storage/sqlite3/invite_table.go
+++ b/roomserver/storage/sqlite3/invite_table.go
@@ -70,14 +70,15 @@ type inviteStatements struct {
selectInvitesAboutToRetireStmt *sql.Stmt
}
-func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
+func createInvitesTable(db *sql.DB) error {
+ _, err := db.Exec(inviteSchema)
+ return err
+}
+
+func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
s := &inviteStatements{
db: db,
}
- _, err := db.Exec(inviteSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertInviteEventStmt, insertInviteEventSQL},
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index d716ced0..d9fe32cf 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -115,7 +115,12 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt
}
-func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
+func createMembershipTable(db *sql.DB) error {
+ _, err := db.Exec(membershipSchema)
+ return err
+}
+
+func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{
db: db,
}
@@ -135,11 +140,6 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
}.Prepare(db)
}
-func (s *membershipStatements) execSchema(db *sql.DB) error {
- _, err := db.Exec(membershipSchema)
- return err
-}
-
func (s *membershipStatements) InsertMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go
index aaee6273..3cb52767 100644
--- a/roomserver/storage/sqlite3/previous_events_table.go
+++ b/roomserver/storage/sqlite3/previous_events_table.go
@@ -71,14 +71,15 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt
}
-func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
+func createPrevEventsTable(db *sql.DB) error {
+ _, err := db.Exec(previousEventSchema)
+ return err
+}
+
+func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
s := &previousEventStatements{
db: db,
}
- _, err := db.Exec(previousEventSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go
index dcf6f697..6d9d9135 100644
--- a/roomserver/storage/sqlite3/published_table.go
+++ b/roomserver/storage/sqlite3/published_table.go
@@ -50,14 +50,16 @@ type publishedStatements struct {
selectPublishedStmt *sql.Stmt
}
-func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
+func createPublishedTable(db *sql.DB) error {
+ _, err := db.Exec(publishedSchema)
+ return err
+}
+
+func preparePublishedTable(db *sql.DB) (tables.Published, error) {
s := &publishedStatements{
db: db,
}
- _, err := db.Exec(publishedSchema)
- if err != nil {
- return nil, err
- }
+
return s, shared.StatementList{
{&s.upsertPublishedStmt, upsertPublishedSQL},
{&s.selectAllPublishedStmt, selectAllPublishedSQL},
diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go
index e6471486..b3498182 100644
--- a/roomserver/storage/sqlite3/redactions_table.go
+++ b/roomserver/storage/sqlite3/redactions_table.go
@@ -59,14 +59,15 @@ type redactionStatements struct {
markRedactionValidatedStmt *sql.Stmt
}
-func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
+func createRedactionsTable(db *sql.DB) error {
+ _, err := db.Exec(redactionsSchema)
+ return err
+}
+
+func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) {
s := &redactionStatements{
db: db,
}
- _, err := db.Exec(redactionsSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertRedactionStmt, insertRedactionSQL},
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index f053e398..5215fa6f 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -64,14 +64,16 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
+func createRoomAliasesTable(db *sql.DB) error {
+ _, err := db.Exec(roomAliasesSchema)
+ return err
+}
+
+func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
- _, err := db.Exec(roomAliasesSchema)
- if err != nil {
- return nil, err
- }
+
return s, shared.StatementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index fe8e601f..534a870c 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -86,14 +86,16 @@ type roomStatements struct {
selectRoomIDsStmt *sql.Stmt
}
-func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
+func createRoomsTable(db *sql.DB) error {
+ _, err := db.Exec(roomsSchema)
+ return err
+}
+
+func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{
db: db,
}
- _, err := db.Exec(roomsSchema)
- if err != nil {
- return nil, err
- }
+
return s, shared.StatementList{
{&s.insertRoomNIDStmt, insertRoomNIDSQL},
{&s.selectRoomNIDStmt, selectRoomNIDSQL},
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index 2c544f2b..cfb2a49e 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -18,6 +18,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "encoding/json"
"fmt"
"sort"
"strings"
@@ -32,228 +33,113 @@ import (
const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
- state_block_nid INTEGER NOT NULL,
- event_type_nid INTEGER NOT NULL,
- event_state_key_nid INTEGER NOT NULL,
- event_nid INTEGER NOT NULL,
- UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
+ -- The state snapshot NID that identifies this snapshot.
+ state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The hash of the state block, which is used to enforce uniqueness. The hash is
+ -- generated in Dendrite and passed through to the database, as a btree index over
+ -- this column is cheap and fits within the maximum index size.
+ state_block_hash BLOB UNIQUE,
+ -- The event NIDs contained within the state block, encoded as JSON.
+ event_nids TEXT NOT NULL DEFAULT '[]'
);
`
-const insertStateDataSQL = "" +
- "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
- " VALUES ($1, $2, $3, $4)"
-
-const selectNextStateBlockNIDSQL = `
-SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
+// Insert a new state block. If we conflict on the hash column then
+// we must perform an update so that the RETURNING statement returns the
+// ID of the row that we conflicted with, so that we can then refer to
+// the original block.
+const insertStateDataSQL = `
+ INSERT INTO roomserver_state_block (state_block_hash, event_nids)
+ VALUES ($1, $2)
+ ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
+ RETURNING state_block_nid
`
-// Bulk state lookup by numeric state block ID.
-// Sort by the state_block_nid, event_type_nid, event_state_key_nid
-// This means that all the entries for a given state_block_nid will appear
-// together in the list and those entries will sorted by event_type_nid
-// and event_state_key_nid. This property makes it easier to merge two
-// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
-
-// Bulk state lookup by numeric state block ID.
-// Filters the rows in each block to the requested types and state keys.
-// We would like to restrict to particular type state key pairs but we are
-// restricted by the query language to pull the cross product of a list
-// of types and a list state_keys. So we have to filter the result in the
-// application to restrict it to the list of event types and state keys we
-// actually wanted.
-const bulkSelectFilteredStateBlockEntriesSQL = "" +
- "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
- " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
- " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
- " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+ "SELECT state_block_nid, event_nids" +
+ " FROM roomserver_state_block WHERE state_block_nid IN ($1)"
type stateBlockStatements struct {
- db *sql.DB
- insertStateDataStmt *sql.Stmt
- selectNextStateBlockNIDStmt *sql.Stmt
- bulkSelectStateBlockEntriesStmt *sql.Stmt
- bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
+ db *sql.DB
+ insertStateDataStmt *sql.Stmt
+ bulkSelectStateBlockEntriesStmt *sql.Stmt
+}
+
+func createStateBlockTable(db *sql.DB) error {
+ _, err := db.Exec(stateDataSchema)
+ return err
}
-func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
- _, err := db.Exec(stateDataSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
- {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
- {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
}
func (s *stateBlockStatements) BulkInsertStateData(
- ctx context.Context, txn *sql.Tx,
- entries []types.StateEntry,
-) (types.StateBlockNID, error) {
- if len(entries) == 0 {
- return 0, nil
- }
- var stateBlockNID types.StateBlockNID
- err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
+ ctx context.Context,
+ txn *sql.Tx,
+ entries types.StateEntries,
+) (id types.StateBlockNID, err error) {
+ entries = entries[:util.SortAndUnique(entries)]
+ var nids types.EventNIDs
+ for _, e := range entries {
+ nids = append(nids, e.EventNID)
+ }
+ js, err := json.Marshal(nids)
if err != nil {
- return 0, err
- }
- for _, entry := range entries {
- _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
- ctx,
- int64(stateBlockNID),
- int64(entry.EventTypeNID),
- int64(entry.EventStateKeyNID),
- int64(entry.EventNID),
- )
- if err != nil {
- return 0, err
- }
+ return 0, fmt.Errorf("json.Marshal: %w", err)
}
- return stateBlockNID, err
+ err = s.insertStateDataStmt.QueryRowContext(
+ ctx, nids.Hash(), js,
+ ).Scan(&id)
+ return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
- ctx context.Context, stateBlockNIDs []types.StateBlockNID,
-) ([]types.StateEntryList, error) {
- nids := make([]interface{}, len(stateBlockNIDs))
- for k, v := range stateBlockNIDs {
- nids[k] = v
+ ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
+) ([][]types.EventNID, error) {
+ intfs := make([]interface{}, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ intfs[i] = int64(stateBlockNIDs[i])
}
- selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
+ selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
- rows, err := selectStmt.QueryContext(ctx, nids...)
+ rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
- results := make([]types.StateEntryList, len(stateBlockNIDs))
- // current is a pointer to the StateEntryList to append the state entries to.
- var current *types.StateEntryList
+ results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err := rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
+ for ; rows.Next(); i++ {
+ var stateBlockNID types.StateBlockNID
+ var result json.RawMessage
+ if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
- if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we start appending to the next entry in the list.
- current = &results[i]
- current.StateBlockNID = types.StateBlockNID(stateBlockNID)
- i++
+ r := []types.EventNID{}
+ if err = json.Unmarshal(result, &r); err != nil {
+ return nil, fmt.Errorf("json.Unmarshal: %w", err)
}
- current.StateEntries = append(current.StateEntries, entry)
- }
- if i != len(nids) {
- return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
+ results[i] = r
}
- return results, nil
-}
-
-func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
- ctx context.Context,
- stateBlockNIDs []types.StateBlockNID,
- stateKeyTuples []types.StateKeyTuple,
-) ([]types.StateEntryList, error) {
- tuples := stateKeyTupleSorter(stateKeyTuples)
- // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
- sort.Sort(tuples)
-
- eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
- sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
- sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
- sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
-
- var params []interface{}
- for _, val := range stateBlockNIDs {
- params = append(params, int64(val))
- }
- for _, val := range eventTypeNIDArray {
- params = append(params, val)
- }
- for _, val := range eventStateKeyNIDArray {
- params = append(params, val)
- }
-
- rows, err := s.db.QueryContext(
- ctx,
- sqlStatement,
- params...,
- )
- if err != nil {
+ if err = rows.Err(); err != nil {
return nil, err
}
- defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
-
- var results []types.StateEntryList
- var current types.StateEntryList
- for rows.Next() {
- var (
- stateBlockNID int64
- eventTypeNID int64
- eventStateKeyNID int64
- eventNID int64
- entry types.StateEntry
- )
- if err := rows.Scan(
- &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
- ); err != nil {
- return nil, err
- }
- entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
- entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
- entry.EventNID = types.EventNID(eventNID)
-
- // We can use binary search here because we sorted the tuples earlier
- if !tuples.contains(entry.StateKeyTuple) {
- // The select will return the cross product of types and state keys.
- // So we need to check if type of the entry is in the list.
- continue
- }
-
- if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
- // The state entry row is for a different state data block to the current one.
- // So we append the current entry to the results and start adding to a new one.
- // The first time through the loop current will be empty.
- if current.StateEntries != nil {
- results = append(results, current)
- }
- current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
- }
- current.StateEntries = append(current.StateEntries, entry)
- }
- // Add the last entry to the list if it is not empty.
- if current.StateEntries != nil {
- results = append(results, current)
+ if i != len(stateBlockNIDs) {
+ return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs))
}
- return results, nil
+ return results, err
}
type stateKeyTupleSorter []types.StateKeyTuple
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index bf49f62c..95cae99e 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -27,19 +27,34 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
)
const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
+ -- The state snapshot NID that identifies this snapshot.
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is
+ -- generated in Dendrite and passed through to the database, as a btree index over
+ -- this column is cheap and fits within the maximum index size.
+ state_snapshot_hash BLOB UNIQUE,
+ -- The room NID that the snapshot belongs to.
room_nid INTEGER NOT NULL,
+ -- The state blocks contained within this snapshot, encoded as JSON.
state_block_nids TEXT NOT NULL DEFAULT '[]'
);
`
+// Insert a new state snapshot. If we conflict on the hash column then
+// we must perform an update so that the RETURNING statement returns the
+// ID of the row that we conflicted with, so that we can then refer to
+// the original snapshot.
const insertStateSQL = `
- INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
- VALUES ($1, $2);`
+ INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
+ VALUES ($1, $2, $3)
+ ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
+ RETURNING state_snapshot_nid
+`
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
@@ -54,14 +69,15 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+func createStateSnapshotTable(db *sql.DB) error {
+ _, err := db.Exec(stateSnapshotSchema)
+ return err
+}
+
+func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
- _, err := db.Exec(stateSnapshotSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL},
@@ -70,22 +86,20 @@ func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
}
func (s *stateSnapshotStatements) InsertState(
- ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
+ stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)]
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
if err != nil {
return
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
- res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
- if err != nil {
- return 0, err
- }
- lastRowID, err := res.LastInsertId()
+ var id int64
+ err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
if err != nil {
return 0, err
}
- stateNID = types.StateSnapshotNID(lastRowID)
+ stateNID = types.StateSnapshotNID(id)
return
}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 8e608a6d..c07ab507 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -53,17 +53,22 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
// which it will never obtain.
db.SetMaxOpenConns(20)
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- ms := membershipStatements{}
- if err := ms.execSchema(db); err != nil {
+ // Create the tables.
+ if err := d.create(db); err != nil {
return nil, err
}
+
+ // Then execute the migrations. By this point the tables are created with the latest
+ // schemas.
m := sqlutil.NewMigrations()
deltas.LoadAddForgottenColumn(m)
+ deltas.LoadStateBlocksRefactor(m)
if err := m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
+
+ // Then prepare the statements. Now that the migrations have run, any columns referred
+ // to in the database code should now exist.
if err := d.prepare(db, cache); err != nil {
return nil, err
}
@@ -71,62 +76,107 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches)
return &d, nil
}
-// nolint: gocyclo
+func (d *Database) create(db *sql.DB) error {
+ if err := createEventStateKeysTable(db); err != nil {
+ return err
+ }
+ if err := createEventTypesTable(db); err != nil {
+ return err
+ }
+ if err := createEventJSONTable(db); err != nil {
+ return err
+ }
+ if err := createEventsTable(db); err != nil {
+ return err
+ }
+ if err := createRoomsTable(db); err != nil {
+ return err
+ }
+ if err := createTransactionsTable(db); err != nil {
+ return err
+ }
+ if err := createStateBlockTable(db); err != nil {
+ return err
+ }
+ if err := createStateSnapshotTable(db); err != nil {
+ return err
+ }
+ if err := createPrevEventsTable(db); err != nil {
+ return err
+ }
+ if err := createRoomAliasesTable(db); err != nil {
+ return err
+ }
+ if err := createInvitesTable(db); err != nil {
+ return err
+ }
+ if err := createMembershipTable(db); err != nil {
+ return err
+ }
+ if err := createPublishedTable(db); err != nil {
+ return err
+ }
+ if err := createRedactionsTable(db); err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
- var err error
- eventStateKeys, err := NewSqliteEventStateKeysTable(db)
+ eventStateKeys, err := prepareEventStateKeysTable(db)
if err != nil {
return err
}
- eventTypes, err := NewSqliteEventTypesTable(db)
+ eventTypes, err := prepareEventTypesTable(db)
if err != nil {
return err
}
- eventJSON, err := NewSqliteEventJSONTable(db)
+ eventJSON, err := prepareEventJSONTable(db)
if err != nil {
return err
}
- events, err := NewSqliteEventsTable(db)
+ events, err := prepareEventsTable(db)
if err != nil {
return err
}
- rooms, err := NewSqliteRoomsTable(db)
+ rooms, err := prepareRoomsTable(db)
if err != nil {
return err
}
- transactions, err := NewSqliteTransactionsTable(db)
+ transactions, err := prepareTransactionsTable(db)
if err != nil {
return err
}
- stateBlock, err := NewSqliteStateBlockTable(db)
+ stateBlock, err := prepareStateBlockTable(db)
if err != nil {
return err
}
- stateSnapshot, err := NewSqliteStateSnapshotTable(db)
+ stateSnapshot, err := prepareStateSnapshotTable(db)
if err != nil {
return err
}
- prevEvents, err := NewSqlitePrevEventsTable(db)
+ prevEvents, err := preparePrevEventsTable(db)
if err != nil {
return err
}
- roomAliases, err := NewSqliteRoomAliasesTable(db)
+ roomAliases, err := prepareRoomAliasesTable(db)
if err != nil {
return err
}
- invites, err := NewSqliteInvitesTable(db)
+ invites, err := prepareInvitesTable(db)
if err != nil {
return err
}
- membership, err := NewSqliteMembershipTable(db)
+ membership, err := prepareMembershipTable(db)
if err != nil {
return err
}
- published, err := NewSqlitePublishedTable(db)
+ published, err := preparePublishedTable(db)
if err != nil {
return err
}
- redactions, err := NewSqliteRedactionsTable(db)
+ redactions, err := prepareRedactionsTable(db)
if err != nil {
return err
}
diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go
index 029122c5..e7471d7b 100644
--- a/roomserver/storage/sqlite3/transactions_table.go
+++ b/roomserver/storage/sqlite3/transactions_table.go
@@ -49,14 +49,15 @@ type transactionStatements struct {
selectTransactionEventIDStmt *sql.Stmt
}
-func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
+func createTransactionsTable(db *sql.DB) error {
+ _, err := db.Exec(transactionsSchema)
+ return err
+}
+
+func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) {
s := &transactionStatements{
db: db,
}
- _, err := db.Exec(transactionsSchema)
- if err != nil {
- return nil, err
- }
return s, shared.StatementList{
{&s.insertTransactionStmt, insertTransactionSQL},
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 26bf5cf0..dd486873 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -43,6 +43,7 @@ type Events interface {
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
+ BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
@@ -81,14 +82,14 @@ type Transactions interface {
}
type StateSnapshot interface {
- InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error)
+ InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
}
type StateBlock interface {
- BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error)
- BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
- BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
+ BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
+ BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
+ //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
}
type RoomAliases interface {
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
index e866f6cb..d7e03ad6 100644
--- a/roomserver/types/types.go
+++ b/roomserver/types/types.go
@@ -16,9 +16,11 @@
package types
import (
+ "encoding/json"
"sort"
"github.com/matrix-org/gomatrixserverlib"
+ "golang.org/x/crypto/blake2b"
)
// EventTypeNID is a numeric ID for an event type.
@@ -40,6 +42,38 @@ type StateSnapshotNID int64
// These blocks of state data are combined to form the actual state.
type StateBlockNID int64
+// EventNIDs is used to sort and dedupe event NIDs.
+type EventNIDs []EventNID
+
+func (a EventNIDs) Len() int { return len(a) }
+func (a EventNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a EventNIDs) Less(i, j int) bool { return a[i] < a[j] }
+
+func (a EventNIDs) Hash() []byte {
+ j, err := json.Marshal(a)
+ if err != nil {
+ return nil
+ }
+ h := blake2b.Sum256(j)
+ return h[:]
+}
+
+// StateBlockNIDs is used to sort and dedupe state block NIDs.
+type StateBlockNIDs []StateBlockNID
+
+func (a StateBlockNIDs) Len() int { return len(a) }
+func (a StateBlockNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a StateBlockNIDs) Less(i, j int) bool { return a[i] < a[j] }
+
+func (a StateBlockNIDs) Hash() []byte {
+ j, err := json.Marshal(a)
+ if err != nil {
+ return nil
+ }
+ h := blake2b.Sum256(j)
+ return h[:]
+}
+
// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
// It is used to lookup state entries.
type StateKeyTuple struct {
@@ -65,6 +99,12 @@ type StateEntry struct {
EventNID EventNID
}
+type StateEntries []StateEntry
+
+func (a StateEntries) Len() int { return len(a) }
+func (a StateEntries) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a StateEntries) Less(i, j int) bool { return a[i].EventNID < a[j].EventNID }
+
// LessThan returns true if this state entry is less than the other state entry.
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
func (a StateEntry) LessThan(b StateEntry) bool {