diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2021-04-26 13:25:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-26 13:25:57 +0100 |
commit | 5ce1fe80dea8b8cfca8712e8d584deb995bbddcc (patch) | |
tree | 1307a1edf73abf68cebd4601efec1e467dac964c /roomserver | |
parent | d6e9b7b307ff0d7541046ec33890d49239c7a6ca (diff) |
State storage refactor (#1839)
* Hash-deduplicated state storage (and migrations) for PostgreSQL and SQLite
* Refactor droomserver database setup for migrations
* Fix conflict statements
* Update migration names
* Set a boundary for old to new block/snapshot IDs so we don't rewrite them more than once accidentally
* Create sequence if not exists
* Fix boundary queries
* Fix boundary queries
* Use Query
* Break out queries a bit
* More sequence tweaks
* Query parameters are not playing the game
* Injection escaping may not work for CREATE SEQUENCE after all
* Fix snapshot sequence name
* Use boundaried IDs in SQLite too
* Use IFNULL for SQLite
* Use COALESCE in PostgreSQL
* Review comments @Kegsay
Diffstat (limited to 'roomserver')
37 files changed, 1073 insertions, 551 deletions
diff --git a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go index 733f0fa1..f3bd8632 100644 --- a/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go +++ b/roomserver/storage/postgres/deltas/20201028212440_add_forgotten_column.go @@ -24,6 +24,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) + goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) } func LoadAddForgottenColumn(m *sqlutil.Migrations) { diff --git a/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go new file mode 100644 index 00000000..84da9614 --- /dev/null +++ b/roomserver/storage/postgres/deltas/2021041615092700_state_blocks_refactor.go @@ -0,0 +1,223 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type stateSnapshotData struct { + StateSnapshotNID types.StateSnapshotNID + RoomNID types.RoomNID +} + +type stateBlockData struct { + stateSnapshotData + StateBlockNID types.StateBlockNID + EventNIDs types.EventNIDs +} + +func LoadStateBlocksRefactor(m *sqlutil.Migrations) { + m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) +} + +// nolint:gocyclo +func UpStateBlocksRefactor(tx *sql.Tx) error { + logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") + defer logrus.Warn("State storage upgrade complete") + + var snapshotcount int + var maxsnapshotid int + var maxblockid int + if err := tx.QueryRow(`SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + if err := tx.QueryRow(`SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + if err := tx.QueryRow(`SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + maxsnapshotid++ + maxblockid++ + + if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + // We create new sequences starting with the maximum state snapshot and block NIDs. + // This means that all newly created snapshots and blocks by the migration will have + // NIDs higher than these values, so that when we come to update the references to + // these NIDs using UPDATE statements, we can guarantee we are only ever updating old + // values and not accidentally overwriting new ones. + if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'), + state_block_hash BYTEA UNIQUE, + event_nids bigint[] NOT NULL + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec (create blocks table): %w", err) + } + _, err = tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'), + state_snapshot_hash BYTEA UNIQUE, + room_nid bigint NOT NULL, + state_block_nids bigint[] NOT NULL + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec (create snapshots table): %w", err) + } + logrus.Warn("New tables created...") + + batchsize := 100 + for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize { + var snapshotrows *sql.Rows + snapshotrows, err = tx.Query(` + SELECT + state_snapshot_nid, + room_nid, + state_block_nid, + ARRAY_AGG(event_nid) AS event_nids + FROM ( + SELECT + _roomserver_state_snapshots.state_snapshot_nid, + _roomserver_state_snapshots.room_nid, + _roomserver_state_block.state_block_nid, + _roomserver_state_block.event_nid + FROM + _roomserver_state_snapshots + JOIN _roomserver_state_block ON _roomserver_state_block.state_block_nid = ANY (_roomserver_state_snapshots.state_block_nids) + WHERE + _roomserver_state_snapshots.state_snapshot_nid = ANY ( SELECT DISTINCT + _roomserver_state_snapshots.state_snapshot_nid + FROM + _roomserver_state_snapshots + LIMIT $1 OFFSET $2)) AS _roomserver_state_block + GROUP BY + state_snapshot_nid, + room_nid, + state_block_nid; + `, batchsize, batchoffset) + if err != nil { + return fmt.Errorf("tx.Query: %w", err) + } + + logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount) + var snapshots []stateBlockData + + for snapshotrows.Next() { + var snapshot stateBlockData + var eventsarray pq.Int64Array + if err = snapshotrows.Scan(&snapshot.StateSnapshotNID, &snapshot.RoomNID, &snapshot.StateBlockNID, &eventsarray); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + for _, e := range eventsarray { + snapshot.EventNIDs = append(snapshot.EventNIDs, types.EventNID(e)) + } + snapshot.EventNIDs = snapshot.EventNIDs[:util.SortAndUnique(snapshot.EventNIDs)] + snapshots = append(snapshots, snapshot) + } + + if err = snapshotrows.Close(); err != nil { + return fmt.Errorf("snapshots.Close: %w", err) + } + + newsnapshots := map[stateSnapshotData]types.StateBlockNIDs{} + + for _, snapshot := range snapshots { + var eventsarray pq.Int64Array + for _, e := range snapshot.EventNIDs { + eventsarray = append(eventsarray, int64(e)) + } + + var blocknid types.StateBlockNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid + `, snapshot.EventNIDs.Hash(), eventsarray).Scan(&blocknid) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new block with %d events): %w", len(eventsarray), err) + } + index := stateSnapshotData{snapshot.StateSnapshotNID, snapshot.RoomNID} + newsnapshots[index] = append(newsnapshots[index], blocknid) + } + + for snapshotdata, newblocks := range newsnapshots { + var newblocksarray pq.Int64Array + for _, b := range newblocks { + newblocksarray = append(newblocksarray, int64(b)) + } + + var newNID types.StateSnapshotNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + RETURNING state_snapshot_nid + `, newblocks.Hash(), snapshotdata.RoomNID, newblocksarray).Scan(&newNID) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) + } + + if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { + return fmt.Errorf("tx.Exec (update events): %w", err) + } + + if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { + return fmt.Errorf("tx.Exec (update rooms): %w", err) + } + } + } + + if _, err = tx.Exec(` + DROP TABLE _roomserver_state_snapshots; + DROP SEQUENCE roomserver_state_snapshot_nid_seq; + `); err != nil { + return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) + } + if _, err = tx.Exec(` + DROP TABLE _roomserver_state_block; + DROP SEQUENCE roomserver_state_block_nid_seq; + `); err != nil { + return fmt.Errorf("tx.Exec (delete old block table): %w", err) + } + + return nil +} + +func DownStateBlocksRefactor(tx *sql.Tx) error { + panic("Downgrading state storage is not supported") +} diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 8f11d1d8..e0976b12 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -59,12 +59,14 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { - s := &eventJSONStatements{} +func createEventJSONTable(db *sql.DB) error { _, err := db.Exec(eventJSONSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go index 500ff20e..61682356 100644 --- a/roomserver/storage/postgres/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -77,12 +77,14 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func NewPostgresEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { - s := &eventStateKeyStatements{} +func createEventStateKeysTable(db *sql.DB) error { _, err := db.Exec(eventStateKeysSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { + s := &eventStateKeyStatements{} + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go index 02d6ad07..f4257850 100644 --- a/roomserver/storage/postgres/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -100,12 +100,13 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewPostgresEventTypesTable(db *sql.DB) (tables.EventTypes, error) { - s := &eventTypeStatements{} +func createEventTypesTable(db *sql.DB) error { _, err := db.Exec(eventTypesSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { + s := &eventTypeStatements{} return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 0cf0bd22..88c82083 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "sort" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -88,6 +89,16 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +// Bulk look up of events by event NID, optionally filtering by the event type +// or event state key NIDs if provided. (The CARDINALITY check will return true +// if the provided arrays are empty, ergo no filtering). +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid = ANY($1)" + + " AND (CARDINALITY($2::bigint[]) = 0 OR event_type_nid = ANY($2))" + + " AND (CARDINALITY($3::bigint[]) = 0 OR event_state_key_nid = ANY($3))" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" @@ -127,6 +138,7 @@ type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt updateEventStateStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt @@ -140,17 +152,19 @@ type eventStatements struct { selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { - s := &eventStatements{} +func createEventsTable(db *sql.DB) error { _, err := db.Exec(eventsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -238,6 +252,42 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, nil } +// bulkSelectStateEventByNID lookups a list of state events by event NID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + sort.Sort(tuples) + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if err = rows.Err(); err != nil { + return nil, err + } + return results[:i], nil +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go index bb719516..0a2183e2 100644 --- a/roomserver/storage/postgres/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -82,12 +82,13 @@ type inviteStatements struct { updateInviteRetiredStmt *sql.Stmt } -func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { - s := &inviteStatements{} +func createInvitesTable(db *sql.DB) error { _, err := db.Exec(inviteSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteStatements{} return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index e392a4fb..3466da6d 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -139,12 +139,13 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt } -func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { - s := &membershipStatements{} +func createMembershipTable(db *sql.DB) error { _, err := db.Exec(membershipSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { + s := &membershipStatements{} return s, shared.StatementList{ {&s.insertMembershipStmt, insertMembershipSQL}, @@ -162,11 +163,6 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { }.Prepare(db) } -func (s *membershipStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(membershipSchema) - return err -} - func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go index 1a4ba673..4a93c3d6 100644 --- a/roomserver/storage/postgres/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -65,12 +65,13 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) { - s := &previousEventStatements{} +func createPrevEventsTable(db *sql.DB) error { _, err := db.Exec(previousEventSchema) - if err != nil { - return nil, err - } + return err +} + +func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { + s := &previousEventStatements{} return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 440ae784..c180576e 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -50,12 +50,14 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func NewPostgresPublishedTable(db *sql.DB) (tables.Published, error) { - s := &publishedStatements{} +func createPublishedTable(db *sql.DB) error { _, err := db.Exec(publishedSchema) - if err != nil { - return nil, err - } + return err +} + +func preparePublishedTable(db *sql.DB) (tables.Published, error) { + s := &publishedStatements{} + return s, shared.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, diff --git a/roomserver/storage/postgres/redactions_table.go b/roomserver/storage/postgres/redactions_table.go index 42aba598..3741d5f6 100644 --- a/roomserver/storage/postgres/redactions_table.go +++ b/roomserver/storage/postgres/redactions_table.go @@ -60,12 +60,13 @@ type redactionStatements struct { markRedactionValidatedStmt *sql.Stmt } -func NewPostgresRedactionsTable(db *sql.DB) (tables.Redactions, error) { - s := &redactionStatements{} +func createRedactionsTable(db *sql.DB) error { _, err := db.Exec(redactionsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { + s := &redactionStatements{} return s, shared.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index b603a673..c808813e 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -62,12 +62,14 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { - s := &roomAliasesStatements{} +func createRoomAliasesTable(db *sql.DB) error { _, err := db.Exec(roomAliasesSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { + s := &roomAliasesStatements{} + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 637680bd..f2b39fe5 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -96,12 +96,14 @@ type roomStatements struct { bulkSelectRoomNIDsStmt *sql.Stmt } -func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { - s := &roomStatements{} +func createRoomsTable(db *sql.DB) error { _, err := db.Exec(roomsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { + s := &roomStatements{} + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index d618686f..4523d18b 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -41,141 +41,88 @@ const stateDataSchema = ` -- which in turn makes it easier to merge state data blocks. CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq; CREATE TABLE IF NOT EXISTS roomserver_state_block ( - -- Local numeric ID for this state data. - state_block_nid bigint NOT NULL, - event_type_nid bigint NOT NULL, - event_state_key_nid bigint NOT NULL, - event_nid bigint NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + -- The state snapshot NID that identifies this snapshot. + state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_seq'), + -- The hash of the state block, which is used to enforce uniqueness. The hash is + -- generated in Dendrite and passed through to the database, as a btree index over + -- this column is cheap and fits within the maximum index size. + state_block_hash BYTEA UNIQUE, + -- The event NIDs contained within the state block. + event_nids bigint[] NOT NULL ); ` +// Insert a new state block. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original block. const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" + "INSERT INTO roomserver_state_block (state_block_hash, event_nids)" + + " VALUES ($1, $2)" + + " ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2" + + " RETURNING state_block_nid" -const selectNextStateBlockNIDSQL = "" + - "SELECT nextval('roomserver_state_block_nid_seq')" - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" type stateBlockStatements struct { - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt } -func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{} +func createStateBlockTable(db *sql.DB) error { _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { + s := &stateBlockStatements{} return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - stateBlockNID, err := s.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - for _, entry := range entries { - _, err := s.insertStateDataStmt.ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids types.EventNIDs + for _, e := range entries { + nids = append(nids, e.EventNID) } - return stateBlockNID, nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err + err = s.insertStateDataStmt.QueryRowContext( + ctx, nids.Hash(), eventNIDsAsArray(nids), + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err = rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result pq.Int64Array + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + for _, e := range result { + r = append(r, types.EventNID(e)) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r } if err = rows.Err(); err != nil { return nil, err @@ -186,71 +133,6 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( - ctx, - stateBlockNIDsAsArray(stateBlockNIDs), - eventTypeNIDArray, - eventStateKeyNIDArray, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, rows.Err() -} - func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 63175955..15e14e2e 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` @@ -40,19 +41,29 @@ const stateSnapshotSchema = ` -- the full state under single state_block_nid. CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq; CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( - -- Local numeric ID for the state. - state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'), - -- Local numeric ID of the room this state is for. - -- Unused in normal operation, but useful for background work or ad-hoc debugging. - room_nid bigint NOT NULL, - -- List of state_block_nids, stored sorted by state_block_nid. - state_block_nids bigint[] NOT NULL + -- The state snapshot NID that identifies this snapshot. + state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'), + -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is + -- generated in Dendrite and passed through to the database, as a btree index over + -- this column is cheap and fits within the maximum index size. + state_snapshot_hash BYTEA UNIQUE, + -- The room NID that the snapshot belongs to. + room_nid bigint NOT NULL, + -- The state blocks contained within this snapshot. + state_block_nids bigint[] NOT NULL ); ` +// Insert a new state snapshot. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original snapshot. const insertStateSQL = "" + - "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + - " VALUES ($1, $2)" + + "INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2" + + // Performing an update, above, ensures that the RETURNING statement + // below will always return a valid state snapshot ID " RETURNING state_snapshot_nid" // Bulk state data NID lookup. @@ -67,12 +78,13 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { - s := &stateSnapshotStatements{} +func createStateSnapshotTable(db *sql.DB) error { _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { + s := &stateSnapshotStatements{} return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, @@ -81,13 +93,15 @@ func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) + nids = nids[:util.SortAndUnique(nids)] + var id int64 + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id) + if err != nil { + return 0, err } - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + stateNID = types.StateSnapshotNID(id) return } diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index bb3f841d..863a1593 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -17,6 +17,7 @@ package postgres import ( "database/sql" + "fmt" // Import the postgres database driver. _ "github.com/lib/pq" @@ -39,20 +40,25 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) var db *sql.DB var err error if db, err = sqlutil.Open(dbProperties); err != nil { - return nil, err + return nil, fmt.Errorf("sqlutil.Open: %w", err) } - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - ms := membershipStatements{} - if err := ms.execSchema(db); err != nil { + // Create the tables. + if err := d.create(db); err != nil { return nil, err } + + // Then execute the migrations. By this point the tables are created with the latest + // schemas. m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) + deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } + + // Then prepare the statements. Now that the migrations have run, any columns referred + // to in the database code should now exist. if err := d.prepare(db, cache); err != nil { return nil, err } @@ -60,61 +66,107 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) return &d, nil } -// nolint: gocyclo -func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err error) { - eventStateKeys, err := NewPostgresEventStateKeysTable(db) +func (d *Database) create(db *sql.DB) error { + if err := createEventStateKeysTable(db); err != nil { + return err + } + if err := createEventTypesTable(db); err != nil { + return err + } + if err := createEventJSONTable(db); err != nil { + return err + } + if err := createEventsTable(db); err != nil { + return err + } + if err := createRoomsTable(db); err != nil { + return err + } + if err := createTransactionsTable(db); err != nil { + return err + } + if err := createStateBlockTable(db); err != nil { + return err + } + if err := createStateSnapshotTable(db); err != nil { + return err + } + if err := createPrevEventsTable(db); err != nil { + return err + } + if err := createRoomAliasesTable(db); err != nil { + return err + } + if err := createInvitesTable(db); err != nil { + return err + } + if err := createMembershipTable(db); err != nil { + return err + } + if err := createPublishedTable(db); err != nil { + return err + } + if err := createRedactionsTable(db); err != nil { + return err + } + + return nil +} + +func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { + eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := NewPostgresEventTypesTable(db) + eventTypes, err := prepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := NewPostgresEventJSONTable(db) + eventJSON, err := prepareEventJSONTable(db) if err != nil { return err } - events, err := NewPostgresEventsTable(db) + events, err := prepareEventsTable(db) if err != nil { return err } - rooms, err := NewPostgresRoomsTable(db) + rooms, err := prepareRoomsTable(db) if err != nil { return err } - transactions, err := NewPostgresTransactionsTable(db) + transactions, err := prepareTransactionsTable(db) if err != nil { return err } - stateBlock, err := NewPostgresStateBlockTable(db) + stateBlock, err := prepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := NewPostgresStateSnapshotTable(db) + stateSnapshot, err := prepareStateSnapshotTable(db) if err != nil { return err } - roomAliases, err := NewPostgresRoomAliasesTable(db) + prevEvents, err := preparePrevEventsTable(db) if err != nil { return err } - prevEvents, err := NewPostgresPreviousEventsTable(db) + roomAliases, err := prepareRoomAliasesTable(db) if err != nil { return err } - invites, err := NewPostgresInvitesTable(db) + invites, err := prepareInvitesTable(db) if err != nil { return err } - membership, err := NewPostgresMembershipTable(db) + membership, err := prepareMembershipTable(db) if err != nil { return err } - published, err := NewPostgresPublishedTable(db) + published, err := preparePublishedTable(db) if err != nil { return err } - redactions, err := NewPostgresRedactionsTable(db) + redactions, err := prepareRedactionsTable(db) if err != nil { return err } diff --git a/roomserver/storage/postgres/transactions_table.go b/roomserver/storage/postgres/transactions_table.go index 5e59ae16..cada0d8a 100644 --- a/roomserver/storage/postgres/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -54,12 +54,13 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func NewPostgresTransactionsTable(db *sql.DB) (tables.Transactions, error) { - s := &transactionStatements{} +func createTransactionsTable(db *sql.DB) error { _, err := db.Exec(transactionsSchema) - if err != nil { - return nil, err - } + return err +} + +func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { + s := &transactionStatements{} return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b48772..096d5d7a 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -118,9 +118,24 @@ func (d *Database) StateEntriesForTuples( stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: entries, + }) + } + return lists, nil } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { @@ -141,8 +156,28 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { + if len(stateBlockNIDs) > 0 { + // Check to see if the event already appears in any of the existing state + // blocks. If it does then we should not add it again, as this will just + // result in excess state blocks and snapshots. + // TODO: Investigate why this is happening - probably input_events.go! + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + if berr != nil { + return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) + } + for i := len(state) - 1; i >= 0; i-- { + for _, events := range blocks { + for _, event := range events { + if state[i].EventNID == event { + state = append(state[:i], state[i+1:]...) + } + } + } + } + } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { + // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { @@ -237,7 +272,24 @@ func (d *Database) StateBlockNIDs( func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, + ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := make([]types.StateEntryList, 0, len(entries)) + for i, entry := range entries { + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: eventNIDs, + }) + } + return lists, nil } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { diff --git a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go index 33fe9e2a..d08ab02d 100644 --- a/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go +++ b/roomserver/storage/sqlite3/deltas/20201028212440_add_forgotten_column.go @@ -24,6 +24,7 @@ import ( func LoadFromGoose() { goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) + goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) } func LoadAddForgottenColumn(m *sqlutil.Migrations) { diff --git a/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go new file mode 100644 index 00000000..56158545 --- /dev/null +++ b/roomserver/storage/sqlite3/deltas/2021041615092700_state_blocks_refactor.go @@ -0,0 +1,168 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +func LoadStateBlocksRefactor(m *sqlutil.Migrations) { + m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor) +} + +func UpStateBlocksRefactor(tx *sql.Tx) error { + logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") + defer logrus.Warn("State storage upgrade complete") + + var maxsnapshotid int + var maxblockid int + if err := tx.QueryRow(`SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + if err := tx.QueryRow(`SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { + return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) + } + maxsnapshotid++ + maxblockid++ + + if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_block_hash BLOB UNIQUE, + event_nids TEXT NOT NULL DEFAULT '[]' + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + _, err = tx.Exec(` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + state_snapshot_hash BLOB UNIQUE, + room_nid INTEGER NOT NULL, + state_block_nids TEXT NOT NULL DEFAULT '[]' + ); + `) + if err != nil { + return fmt.Errorf("tx.Exec: %w", err) + } + snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`) + if err != nil { + return fmt.Errorf("tx.Query: %w", err) + } + defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed") + for snapshotrows.Next() { + var snapshot types.StateSnapshotNID + var room types.RoomNID + var jsonblocks string + var blocks []types.StateBlockNID + if err = snapshotrows.Scan(&snapshot, &room, &jsonblocks); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + if err = json.Unmarshal([]byte(jsonblocks), &blocks); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + + var newblocks types.StateBlockNIDs + for _, block := range blocks { + if err = func() error { + blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block) + if berr != nil { + return fmt.Errorf("tx.Query (event nids from old block): %w", berr) + } + defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed") + events := types.EventNIDs{} + for blockrows.Next() { + var event types.EventNID + if err = blockrows.Scan(&event); err != nil { + return fmt.Errorf("rows.Scan: %w", err) + } + events = append(events, event) + } + events = events[:util.SortAndUnique(events)] + eventjson, eerr := json.Marshal(events) + if eerr != nil { + return fmt.Errorf("json.Marshal: %w", eerr) + } + + var blocknid types.StateBlockNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3 + RETURNING state_block_nid + `, maxblockid, events.Hash(), eventjson).Scan(&blocknid) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err) + } + maxblockid++ + newblocks = append(newblocks, blocknid) + return nil + }(); err != nil { + return err + } + + newblocksjson, jerr := json.Marshal(newblocks) + if jerr != nil { + return fmt.Errorf("json.Marshal (new blocks): %w", jerr) + } + var newsnapshot types.StateSnapshotNID + err = tx.QueryRow(` + INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3, $4) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3 + RETURNING state_snapshot_nid + `, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot) + if err != nil { + return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) + } + maxsnapshotid++ + if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$2`, newsnapshot, snapshot, maxsnapshotid); err != nil { + return fmt.Errorf("tx.Exec (update events): %w", err) + } + if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$2`, newsnapshot, snapshot, maxsnapshotid); err != nil { + return fmt.Errorf("tx.Exec (update rooms): %w", err) + } + } + } + + if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { + return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) + } + if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil { + return fmt.Errorf("tx.Exec (delete old block table): %w", err) + } + + return nil +} + +func DownStateBlocksRefactor(tx *sql.Tx) error { + panic("Downgrading state storage is not supported") +} diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index 3cd44b1d..29d54b83 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -53,14 +53,16 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { +func createEventJSONTable(db *sql.DB) error { + _, err := db.Exec(eventJSONSchema) + return err +} + +func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ db: db, } - _, err := db.Exec(eventJSONSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index 345df8c6..d430e553 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -70,14 +70,16 @@ type eventStateKeyStatements struct { bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { +func createEventStateKeysTable(db *sql.DB) error { + _, err := db.Exec(eventStateKeysSchema) + return err +} + +func prepareEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ db: db, } - _, err := db.Exec(eventStateKeysSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 26e2bf84..694f4e21 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -85,14 +85,15 @@ type eventTypeStatements struct { bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { +func createEventTypesTable(db *sql.DB) error { + _, err := db.Exec(eventTypesSchema) + return err +} + +func prepareEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ db: db, } - _, err := db.Exec(eventTypesSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 53269657..e964770d 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sort" "strings" "github.com/matrix-org/dendrite/internal" @@ -63,6 +64,11 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id IN ($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid IN ($1)" + // Rest of query is built by BulkSelectStateEventByNID + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" @@ -115,14 +121,15 @@ type eventStatements struct { //selectRoomNIDsForEventNIDsStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { +func createEventsTable(db *sql.DB) error { + _, err := db.Exec(eventsSchema) + return err +} + +func prepareEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ db: db, } - _, err := db.Exec(eventsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertEventStmt, insertEventSQL}, @@ -232,6 +239,61 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, err } +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntry, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + sort.Sort(tuples) + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray)) + selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1) + for _, v := range eventNIDs { + params = append(params, v) + } + if len(eventTypeNIDArray) > 0 { + selectOrig += " AND event_type_nid IN " + sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(params)) + for _, v := range eventTypeNIDArray { + params = append(params, v) + } + } + if len(eventStateKeyNIDArray) > 0 { + selectOrig += " AND event_state_key_nid IN " + sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(params)) + for _, v := range eventStateKeyNIDArray { + params = append(params, v) + } + } + selectOrig += " ORDER BY event_type_nid, event_state_key_nid ASC" + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, fmt.Errorf("s.db.Prepare: %w", err) + } + rows, err := selectStmt.QueryContext(ctx, params...) + if err != nil { + return nil, fmt.Errorf("selectStmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + return results[:i], err +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 327be6a0..e1aa1ebd 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -70,14 +70,15 @@ type inviteStatements struct { selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { +func createInvitesTable(db *sql.DB) error { + _, err := db.Exec(inviteSchema) + return err +} + +func prepareInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{ db: db, } - _, err := db.Exec(inviteSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertInviteEventStmt, insertInviteEventSQL}, diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index d716ced0..d9fe32cf 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -115,7 +115,12 @@ type membershipStatements struct { updateMembershipForgetRoomStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { +func createMembershipTable(db *sql.DB) error { + _, err := db.Exec(membershipSchema) + return err +} + +func prepareMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ db: db, } @@ -135,11 +140,6 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { }.Prepare(db) } -func (s *membershipStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(membershipSchema) - return err -} - func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index aaee6273..3cb52767 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -71,14 +71,15 @@ type previousEventStatements struct { selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { +func createPrevEventsTable(db *sql.DB) error { + _, err := db.Exec(previousEventSchema) + return err +} + +func preparePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ db: db, } - _, err := db.Exec(previousEventSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index dcf6f697..6d9d9135 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -50,14 +50,16 @@ type publishedStatements struct { selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { +func createPublishedTable(db *sql.DB) error { + _, err := db.Exec(publishedSchema) + return err +} + +func preparePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ db: db, } - _, err := db.Exec(publishedSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.upsertPublishedStmt, upsertPublishedSQL}, {&s.selectAllPublishedStmt, selectAllPublishedSQL}, diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index e6471486..b3498182 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -59,14 +59,15 @@ type redactionStatements struct { markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { +func createRedactionsTable(db *sql.DB) error { + _, err := db.Exec(redactionsSchema) + return err +} + +func prepareRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{ db: db, } - _, err := db.Exec(redactionsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertRedactionStmt, insertRedactionSQL}, diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index f053e398..5215fa6f 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -64,14 +64,16 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { +func createRoomAliasesTable(db *sql.DB) error { + _, err := db.Exec(roomAliasesSchema) + return err +} + +func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ db: db, } - _, err := db.Exec(roomAliasesSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertRoomAliasStmt, insertRoomAliasSQL}, {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index fe8e601f..534a870c 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -86,14 +86,16 @@ type roomStatements struct { selectRoomIDsStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { +func createRoomsTable(db *sql.DB) error { + _, err := db.Exec(roomsSchema) + return err +} + +func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ db: db, } - _, err := db.Exec(roomsSchema) - if err != nil { - return nil, err - } + return s, shared.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 2c544f2b..cfb2a49e 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "fmt" "sort" "strings" @@ -32,228 +33,113 @@ import ( const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + -- The state snapshot NID that identifies this snapshot. + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The hash of the state block, which is used to enforce uniqueness. The hash is + -- generated in Dendrite and passed through to the database, as a btree index over + -- this column is cheap and fits within the maximum index size. + state_block_hash BLOB UNIQUE, + -- The event NIDs contained within the state block, encoded as JSON. + event_nids TEXT NOT NULL DEFAULT '[]' ); ` -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +// Insert a new state block. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original block. +const insertStateDataSQL = ` + INSERT INTO roomserver_state_block (state_block_hash, event_nids) + VALUES ($1, $2) + ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 + RETURNING state_block_nid ` -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + db *sql.DB + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt +} + +func createStateBlockTable(db *sql.DB) error { + _, err := db.Exec(stateDataSchema) + return err } -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ db: db, } - _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil - } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + ctx context.Context, + txn *sql.Tx, + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids types.EventNIDs + for _, e := range entries { + nids = append(nids, e.EventNID) + } + js, err := json.Marshal(nids) if err != nil { - return 0, err - } - for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } + return 0, fmt.Errorf("json.Marshal: %w", err) } - return stateBlockNID, err + err = s.insertStateDataStmt.QueryRowContext( + ctx, nids.Hash(), js, + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + intfs := make([]interface{}, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + intfs[i] = int64(stateBlockNIDs[i]) } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) + rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result json.RawMessage + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + if err = json.Unmarshal(result, &r); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } - current.StateEntries = append(current.StateEntries, entry) - } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) + results[i] = r } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) - } - - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) - if err != nil { + if err = rows.Err(); err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs)) } - return results, nil + return results, err } type stateKeyTupleSorter []types.StateKeyTuple diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index bf49f62c..95cae99e 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -27,19 +27,34 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + -- The state snapshot NID that identifies this snapshot. state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The hash of the state snapshot, which is used to enforce uniqueness. The hash is + -- generated in Dendrite and passed through to the database, as a btree index over + -- this column is cheap and fits within the maximum index size. + state_snapshot_hash BLOB UNIQUE, + -- The room NID that the snapshot belongs to. room_nid INTEGER NOT NULL, + -- The state blocks contained within this snapshot, encoded as JSON. state_block_nids TEXT NOT NULL DEFAULT '[]' ); ` +// Insert a new state snapshot. If we conflict on the hash column then +// we must perform an update so that the RETURNING statement returns the +// ID of the row that we conflicted with, so that we can then refer to +// the original snapshot. const insertStateSQL = ` - INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2);` + INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) + VALUES ($1, $2, $3) + ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 + RETURNING state_snapshot_nid +` // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result @@ -54,14 +69,15 @@ type stateSnapshotStatements struct { bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func createStateSnapshotTable(db *sql.DB) error { + _, err := db.Exec(stateSnapshotSchema) + return err +} + +func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ db: db, } - _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertStateStmt, insertStateSQL}, @@ -70,22 +86,20 @@ func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { + stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) if err != nil { return } insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) - if err != nil { - return 0, err - } - lastRowID, err := res.LastInsertId() + var id int64 + err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id) if err != nil { return 0, err } - stateNID = types.StateSnapshotNID(lastRowID) + stateNID = types.StateSnapshotNID(id) return } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 8e608a6d..c07ab507 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -53,17 +53,22 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) // which it will never obtain. db.SetMaxOpenConns(20) - // Create tables before executing migrations so we don't fail if the table is missing, - // and THEN prepare statements so we don't fail due to referencing new columns - ms := membershipStatements{} - if err := ms.execSchema(db); err != nil { + // Create the tables. + if err := d.create(db); err != nil { return nil, err } + + // Then execute the migrations. By this point the tables are created with the latest + // schemas. m := sqlutil.NewMigrations() deltas.LoadAddForgottenColumn(m) + deltas.LoadStateBlocksRefactor(m) if err := m.RunDeltas(db, dbProperties); err != nil { return nil, err } + + // Then prepare the statements. Now that the migrations have run, any columns referred + // to in the database code should now exist. if err := d.prepare(db, cache); err != nil { return nil, err } @@ -71,62 +76,107 @@ func Open(dbProperties *config.DatabaseOptions, cache caching.RoomServerCaches) return &d, nil } -// nolint: gocyclo +func (d *Database) create(db *sql.DB) error { + if err := createEventStateKeysTable(db); err != nil { + return err + } + if err := createEventTypesTable(db); err != nil { + return err + } + if err := createEventJSONTable(db); err != nil { + return err + } + if err := createEventsTable(db); err != nil { + return err + } + if err := createRoomsTable(db); err != nil { + return err + } + if err := createTransactionsTable(db); err != nil { + return err + } + if err := createStateBlockTable(db); err != nil { + return err + } + if err := createStateSnapshotTable(db); err != nil { + return err + } + if err := createPrevEventsTable(db); err != nil { + return err + } + if err := createRoomAliasesTable(db); err != nil { + return err + } + if err := createInvitesTable(db); err != nil { + return err + } + if err := createMembershipTable(db); err != nil { + return err + } + if err := createPublishedTable(db); err != nil { + return err + } + if err := createRedactionsTable(db); err != nil { + return err + } + + return nil +} + func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { - var err error - eventStateKeys, err := NewSqliteEventStateKeysTable(db) + eventStateKeys, err := prepareEventStateKeysTable(db) if err != nil { return err } - eventTypes, err := NewSqliteEventTypesTable(db) + eventTypes, err := prepareEventTypesTable(db) if err != nil { return err } - eventJSON, err := NewSqliteEventJSONTable(db) + eventJSON, err := prepareEventJSONTable(db) if err != nil { return err } - events, err := NewSqliteEventsTable(db) + events, err := prepareEventsTable(db) if err != nil { return err } - rooms, err := NewSqliteRoomsTable(db) + rooms, err := prepareRoomsTable(db) if err != nil { return err } - transactions, err := NewSqliteTransactionsTable(db) + transactions, err := prepareTransactionsTable(db) if err != nil { return err } - stateBlock, err := NewSqliteStateBlockTable(db) + stateBlock, err := prepareStateBlockTable(db) if err != nil { return err } - stateSnapshot, err := NewSqliteStateSnapshotTable(db) + stateSnapshot, err := prepareStateSnapshotTable(db) if err != nil { return err } - prevEvents, err := NewSqlitePrevEventsTable(db) + prevEvents, err := preparePrevEventsTable(db) if err != nil { return err } - roomAliases, err := NewSqliteRoomAliasesTable(db) + roomAliases, err := prepareRoomAliasesTable(db) if err != nil { return err } - invites, err := NewSqliteInvitesTable(db) + invites, err := prepareInvitesTable(db) if err != nil { return err } - membership, err := NewSqliteMembershipTable(db) + membership, err := prepareMembershipTable(db) if err != nil { return err } - published, err := NewSqlitePublishedTable(db) + published, err := preparePublishedTable(db) if err != nil { return err } - redactions, err := NewSqliteRedactionsTable(db) + redactions, err := prepareRedactionsTable(db) if err != nil { return err } diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 029122c5..e7471d7b 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -49,14 +49,15 @@ type transactionStatements struct { selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { +func createTransactionsTable(db *sql.DB) error { + _, err := db.Exec(transactionsSchema) + return err +} + +func prepareTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{ db: db, } - _, err := db.Exec(transactionsSchema) - if err != nil { - return nil, err - } return s, shared.StatementList{ {&s.insertTransactionStmt, insertTransactionSQL}, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 26bf5cf0..dd486873 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -43,6 +43,7 @@ type Events interface { // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. @@ -81,14 +82,14 @@ type Transactions interface { } type StateSnapshot interface { - InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { - BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index e866f6cb..d7e03ad6 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,9 +16,11 @@ package types import ( + "encoding/json" "sort" "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/blake2b" ) // EventTypeNID is a numeric ID for an event type. @@ -40,6 +42,38 @@ type StateSnapshotNID int64 // These blocks of state data are combined to form the actual state. type StateBlockNID int64 +// EventNIDs is used to sort and dedupe event NIDs. +type EventNIDs []EventNID + +func (a EventNIDs) Len() int { return len(a) } +func (a EventNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a EventNIDs) Less(i, j int) bool { return a[i] < a[j] } + +func (a EventNIDs) Hash() []byte { + j, err := json.Marshal(a) + if err != nil { + return nil + } + h := blake2b.Sum256(j) + return h[:] +} + +// StateBlockNIDs is used to sort and dedupe state block NIDs. +type StateBlockNIDs []StateBlockNID + +func (a StateBlockNIDs) Len() int { return len(a) } +func (a StateBlockNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateBlockNIDs) Less(i, j int) bool { return a[i] < a[j] } + +func (a StateBlockNIDs) Hash() []byte { + j, err := json.Marshal(a) + if err != nil { + return nil + } + h := blake2b.Sum256(j) + return h[:] +} + // A StateKeyTuple is a pair of a numeric event type and a numeric state key. // It is used to lookup state entries. type StateKeyTuple struct { @@ -65,6 +99,12 @@ type StateEntry struct { EventNID EventNID } +type StateEntries []StateEntry + +func (a StateEntries) Len() int { return len(a) } +func (a StateEntries) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateEntries) Less(i, j int) bool { return a[i].EventNID < a[j].EventNID } + // LessThan returns true if this state entry is less than the other state entry. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateEntry) LessThan(b StateEntry) bool { |