diff options
author | Kegsay <kegan@matrix.org> | 2020-02-13 17:27:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-13 17:27:33 +0000 |
commit | b6ea1bc67ab51667b9e139dd05e0778aca025501 (patch) | |
tree | 18569c317fd28544144c320ce844d93a8ff8ec5e /roomserver | |
parent | 6942ee1de0250235164cf0ce45570b7fc919669d (diff) |
Support sqlite in addition to postgres (#869)
* Move current work into single branch
* Initial massaging of clientapi etc (not working yet)
* Interfaces for accounts/devices databases
* Duplicate postgres package for sqlite3 (no changes made to it yet)
* Some keydb, accountdb, devicedb, common partition fixes, some more syncapi tweaking
* Fix accounts DB, device DB
* Update naffka dependency for SQLite
* Naffka SQLite
* Update naffka to latest master
* SQLite support for federationsender
* Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported)
* Update Dockerfile -> Go 1.13.7, add build-base (as gcc and friends are needed for SQLite)
* Implement GET endpoints for account_data in clientapi
* Nuke filtering for now...
* Revert "Implement GET endpoints for account_data in clientapi"
This reverts commit 4d80dff4583d278620d9b3ed437e9fcd8d4674ee.
* Implement GET endpoints for account_data in clientapi (#861)
* Implement GET endpoints for account_data in clientapi
* Fix accountDB parameter
* Remove fmt.Println
* Fix insertAccountData SQLite query
* Fix accountDB storage interfaces
* Add empty push rules into account data on account creation (#862)
* Put SaveAccountData into the right function this time
* Not sure if roomserver is better or worse now
* sqlite work
* Allow empty last sent ID for the first event
* sqlite: room creation works
* Support sending messages
* Nuke fmt.println
* Move QueryVariadic etc into common, other device fixes
* Fix some linter issues
* Fix bugs
* Fix some linting errors
* Fix errcheck lint errors
* Make naffka use postgres as fallback, fix couple of compile errors
* What on earth happened to the /rooms/{roomID}/send/{eventType} routing
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'roomserver')
20 files changed, 3189 insertions, 11 deletions
diff --git a/roomserver/input/events.go b/roomserver/input/events.go index 03023a4a..a3b70753 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -196,7 +196,12 @@ func processInviteEvent( return err } succeeded := false - defer common.EndTransaction(updater, &succeeded) + defer func() { + txerr := common.EndTransaction(updater, &succeeded) + if err == nil && txerr != nil { + err = txerr + } + }() if updater.IsJoin() { // If the user is joined to the room then that takes precedence over this diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go index 7e03d544..f9fd1d5d 100644 --- a/roomserver/input/latest_events.go +++ b/roomserver/input/latest_events.go @@ -60,7 +60,12 @@ func updateLatestEvents( return } succeeded := false - defer common.EndTransaction(updater, &succeeded) + defer func() { + txerr := common.EndTransaction(updater, &succeeded) + if err == nil && txerr != nil { + err = txerr + } + }() u := latestEventsUpdater{ ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID, diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go new file mode 100644 index 00000000..f6c83906 --- /dev/null +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -0,0 +1,108 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventJSONSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_json ( + event_nid INTEGER NOT NULL PRIMARY KEY, + event_json TEXT NOT NULL + ); +` + +const insertEventJSONSQL = ` + INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2) + ON CONFLICT DO NOTHING +` + +// Bulk event JSON lookup by numeric event ID. +// Sort by the numeric event ID. +// This means that we can use binary search to lookup by numeric event ID. +const bulkSelectEventJSONSQL = ` + SELECT event_nid, event_json FROM roomserver_event_json + WHERE event_nid IN ($1) + ORDER BY event_nid ASC +` + +type eventJSONStatements struct { + db *sql.DB + insertEventJSONStmt *sql.Stmt + bulkSelectEventJSONStmt *sql.Stmt +} + +func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(eventJSONSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventJSONStmt, insertEventJSONSQL}, + {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, + }.prepare(db) +} + +func (s *eventJSONStatements) insertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, +) error { + _, err := common.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err +} + +type eventJSONPair struct { + EventNID types.EventNID + EventJSON []byte +} + +func (s *eventJSONStatements) bulkSelectEventJSON( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]eventJSONPair, error) { + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + + rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]eventJSONPair, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var eventNID int64 + if err := rows.Scan(&eventNID, &result.EventJSON); err != nil { + return nil, err + } + result.EventNID = types.EventNID(eventNID) + } + return results[:i], nil +} diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go new file mode 100644 index 00000000..b8bc6c02 --- /dev/null +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -0,0 +1,156 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventStateKeysSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_state_keys ( + event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_state_key TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key) + VALUES (1, '') + ON CONFLICT DO NOTHING; +` + +// Same as insertEventTypeNIDSQL +const insertEventStateKeyNIDSQL = ` + INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1) + ON CONFLICT DO NOTHING; +` + +const selectEventStateKeyNIDSQL = ` + SELECT event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key = $1 +` + +// Bulk lookup from string state key to numeric ID for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeyNIDSQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key IN ($1) +` + +// Bulk lookup from numeric ID to string state key for that state key. +// Takes an array of strings as the query parameter. +const bulkSelectEventStateKeySQL = ` + SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys + WHERE event_state_key_nid IN ($1) +` + +type eventStateKeyStatements struct { + db *sql.DB + insertEventStateKeyNIDStmt *sql.Stmt + selectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyNIDStmt *sql.Stmt + bulkSelectEventStateKeyStmt *sql.Stmt +} + +func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(eventStateKeysSchema) + if err != nil { + return + } + return statementList{ + {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL}, + {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL}, + {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL}, + }.prepare(db) +} + +func (s *eventStateKeyStatements) insertEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + var err error + var res sql.Result + insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt) + if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil { + eventStateKeyNID, err = res.LastInsertId() + } + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) selectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + var eventStateKeyNID int64 + stmt := txn.Stmt(s.selectEventStateKeyNIDStmt) + err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID) + return types.EventStateKeyNID(eventStateKeyNID), err +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + iEventStateKeys := make([]interface{}, len(eventStateKeys)) + for k, v := range eventStateKeys { + iEventStateKeys[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1) + + rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[string]types.EventStateKeyNID, len(eventStateKeys)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[stateKey] = types.EventStateKeyNID(stateKeyNID) + } + return result, nil +} + +func (s *eventStateKeyStatements) bulkSelectEventStateKey( + ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs)) + for k, v := range eventStateKeyNIDs { + iEventStateKeyNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1) + + rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs)) + for rows.Next() { + var stateKey string + var stateKeyNID int64 + if err := rows.Scan(&stateKey, &stateKeyNID); err != nil { + return nil, err + } + result[types.EventStateKeyNID(stateKeyNID)] = stateKey + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go new file mode 100644 index 00000000..edc06d4c --- /dev/null +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -0,0 +1,153 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const eventTypesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_event_types ( + event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL UNIQUE + ); + INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES + (1, 'm.room.create'), + (2, 'm.room.power_levels'), + (3, 'm.room.join_rules'), + (4, 'm.room.third_party_invite'), + (5, 'm.room.member'), + (6, 'm.room.redaction'), + (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING; +` + +// Assign a new numeric event type ID. +// The usual case is that the event type is not in the database. +// In that case the ID will be assigned using the next value from the sequence. +// We use `RETURNING` to tell postgres to return the assigned ID. +// But it's possible that the type was added in a query that raced with us. +// This will result in a conflict on the event_type_unique constraint, in this +// case we do nothing. Postgresql won't return a row in that case so we rely on +// the caller catching the sql.ErrNoRows error and running a select to get the row. +// We could get postgresql to return the row on a conflict by updating the row +// but it doesn't seem like a good idea to modify the rows just to make postgresql +// return it. Modifying the rows will cause postgres to assign a new tuple for the +// row even though the data doesn't change resulting in unncesssary modifications +// to the indexes. +const insertEventTypeNIDSQL = ` + INSERT INTO roomserver_event_types (event_type) VALUES ($1) + ON CONFLICT DO NOTHING; +` + +const insertEventTypeNIDResultSQL = ` + SELECT event_type_nid FROM roomserver_event_types + WHERE rowid = last_insert_rowid(); +` + +const selectEventTypeNIDSQL = ` + SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1 +` + +// Bulk lookup from string event type to numeric ID for that event type. +// Takes an array of strings as the query parameter. +const bulkSelectEventTypeNIDSQL = ` + SELECT event_type, event_type_nid FROM roomserver_event_types + WHERE event_type IN ($1) +` + +type eventTypeStatements struct { + db *sql.DB + insertEventTypeNIDStmt *sql.Stmt + insertEventTypeNIDResultStmt *sql.Stmt + selectEventTypeNIDStmt *sql.Stmt + bulkSelectEventTypeNIDStmt *sql.Stmt +} + +func (s *eventTypeStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(eventTypesSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL}, + {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL}, + {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL}, + {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL}, + }.prepare(db) +} + +func (s *eventTypeStatements) insertEventTypeNID( + ctx context.Context, tx *sql.Tx, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + var err error + insertStmt := common.TxStmt(tx, s.insertEventTypeNIDStmt) + resultStmt := common.TxStmt(tx, s.insertEventTypeNIDResultStmt) + if _, err = insertStmt.ExecContext(ctx, eventType); err == nil { + err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) + } + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) selectEventTypeNID( + ctx context.Context, tx *sql.Tx, eventType string, +) (types.EventTypeNID, error) { + var eventTypeNID int64 + selectStmt := common.TxStmt(tx, s.selectEventTypeNIDStmt) + err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID) + return types.EventTypeNID(eventTypeNID), err +} + +func (s *eventTypeStatements) bulkSelectEventTypeNID( + ctx context.Context, tx *sql.Tx, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + /////////////// + iEventTypes := make([]interface{}, len(eventTypes)) + for k, v := range eventTypes { + iEventTypes[k] = v + } + selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(tx, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventTypes...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + result := make(map[string]types.EventTypeNID, len(eventTypes)) + for rows.Next() { + var eventType string + var eventTypeNID int64 + if err := rows.Scan(&eventType, &eventTypeNID); err != nil { + return nil, err + } + result[eventType] = types.EventTypeNID(eventTypeNID) + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go new file mode 100644 index 00000000..4ed1395d --- /dev/null +++ b/roomserver/storage/sqlite3/events_table.go @@ -0,0 +1,479 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const eventsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_events ( + event_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + sent_to_output BOOLEAN NOT NULL DEFAULT FALSE, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + depth INTEGER NOT NULL, + event_id TEXT NOT NULL UNIQUE, + reference_sha256 BLOB NOT NULL, + auth_event_nids TEXT NOT NULL DEFAULT '{}' + ); +` + +const insertEventSQL = ` + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT DO NOTHING; +` + +const insertEventResultSQL = ` + SELECT event_nid, state_snapshot_nid FROM roomserver_events + WHERE rowid = last_insert_rowid(); +` + +const selectEventSQL = "" + + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" + +// Bulk lookup of events by string ID. +// Sort by the numeric IDs for event type and state key. +// This means we can use binary search to lookup entries by type and state key. +const bulkSelectStateEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_id IN ($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + +const bulkSelectStateAtEventByIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + " WHERE event_id IN ($1)" + +const updateEventStateSQL = "" + + "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2" + +const selectEventSentToOutputSQL = "" + + "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1" + +const updateEventSentToOutputSQL = "" + + "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1" + +const selectEventIDSQL = "" + + "SELECT event_id FROM roomserver_events WHERE event_nid = $1" + +const bulkSelectStateAtEventAndReferenceSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" + + " FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventReferenceSQL = "" + + "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventIDSQL = "" + + "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)" + +const bulkSelectEventNIDSQL = "" + + "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)" + +const selectMaxEventDepthSQL = "" + + "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)" + +type eventStatements struct { + db *sql.DB + insertEventStmt *sql.Stmt + insertEventResultStmt *sql.Stmt + selectEventStmt *sql.Stmt + bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateAtEventByIDStmt *sql.Stmt + updateEventStateStmt *sql.Stmt + selectEventSentToOutputStmt *sql.Stmt + updateEventSentToOutputStmt *sql.Stmt + selectEventIDStmt *sql.Stmt + bulkSelectStateAtEventAndReferenceStmt *sql.Stmt + bulkSelectEventReferenceStmt *sql.Stmt + bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt + selectMaxEventDepthStmt *sql.Stmt +} + +func (s *eventStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(eventsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertEventStmt, insertEventSQL}, + {&s.insertEventResultStmt, insertEventResultSQL}, + {&s.selectEventStmt, selectEventSQL}, + {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, + {&s.updateEventStateStmt, updateEventStateSQL}, + {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, + {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL}, + {&s.selectEventIDStmt, selectEventIDSQL}, + {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, + {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, + {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, + {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, + {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, + }.prepare(db) +} + +func (s *eventStatements) insertEvent( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventTypeNID types.EventTypeNID, + eventStateKeyNID types.EventStateKeyNID, + eventID string, + referenceSHA256 []byte, + authEventNIDs []types.EventNID, + depth int64, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + var err error + insertStmt := common.TxStmt(txn, s.insertEventStmt) + resultStmt := common.TxStmt(txn, s.insertEventResultStmt) + if _, err = insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ); err == nil { + err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) + } + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +func (s *eventStatements) selectEvent( + ctx context.Context, txn *sql.Tx, eventID string, +) (types.EventNID, types.StateSnapshotNID, error) { + var eventNID int64 + var stateNID int64 + selectStmt := common.TxStmt(txn, s.selectEventStmt) + err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID) + return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err +} + +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) bulkSelectStateEventByID( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StateEntry, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + selectPrep, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the common case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError. +// If we do not have the state for any of the requested events it returns a types.MissingEventError. +func (s *eventStatements) bulkSelectStateAtEventByID( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StateAtEvent, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + selectPrep, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEvent, len(eventIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + &result.BeforeStateSnapshotNID, + ); err != nil { + return nil, err + } + if result.BeforeStateSnapshotNID == 0 { + return nil, types.MissingEventError( + fmt.Sprintf("storage: missing state for event NID %d", result.EventNID), + ) + } + } + if i != len(eventIDs) { + return nil, types.MissingEventError( + fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)), + ) + } + return results, err +} + +func (s *eventStatements) updateEventState( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + updateStmt := common.TxStmt(txn, s.updateEventStateStmt) + _, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err +} + +func (s *eventStatements) selectEventSentToOutput( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (sentToOutput bool, err error) { + selectStmt := common.TxStmt(txn, s.selectEventSentToOutputStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput) + if err != nil { + } + return +} + +func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { + updateStmt := common.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID)) + return err +} + +func (s *eventStatements) selectEventID( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, +) (eventID string, err error) { + selectStmt := common.TxStmt(txn, s.selectEventIDStmt) + err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID) + return +} + +func (s *eventStatements) bulkSelectStateAtEventAndReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]types.StateAtEventAndReference, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + ////////////// + + rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateAtEventAndReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var ( + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + stateSnapshotNID int64 + eventID string + eventSHA256 []byte + ) + if err = rows.Scan( + &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256, + ); err != nil { + return nil, err + } + result := &results[i] + result.EventTypeNID = types.EventTypeNID(eventTypeNID) + result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + result.EventNID = types.EventNID(eventNID) + result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID) + result.EventID = eventID + result.EventSHA256 = eventSHA256 + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +func (s *eventStatements) bulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, +) ([]gomatrixserverlib.EventReference, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + selectPrep, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]gomatrixserverlib.EventReference, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventID returns a map from numeric event ID to string event ID. +func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1) + selectPrep, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[types.EventNID]string, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + var eventNID int64 + var eventID string + if err = rows.Scan(&eventNID, &eventID); err != nil { + return nil, err + } + results[types.EventNID(eventNID)] = eventID + } + if i != len(eventNIDs) { + return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs)) + } + return results, nil +} + +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventIDs)) + for k, v := range eventIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1) + selectPrep, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make(map[string]types.EventNID, len(eventIDs)) + for rows.Next() { + var eventID string + var eventNID int64 + if err = rows.Scan(&eventID, &eventNID); err != nil { + return nil, err + } + results[eventID] = types.EventNID(eventNID) + } + return results, nil +} + +func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { + var result int64 + selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt) + err := selectStmt.QueryRowContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))).Scan(&result) + if err != nil { + return 0, err + } + return result, nil +} + +func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { + nids := make([]int64, len(eventNIDs)) + for i := range eventNIDs { + nids[i] = int64(eventNIDs[i]) + } + return nids +} diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go new file mode 100644 index 00000000..5a0f0bf7 --- /dev/null +++ b/roomserver/storage/sqlite3/invite_table.go @@ -0,0 +1,142 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const inviteSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_invites ( + invite_event_id TEXT PRIMARY KEY, + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + retired BOOLEAN NOT NULL DEFAULT FALSE, + invite_event_json TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid) + WHERE NOT retired; +` +const insertInviteEventSQL = "" + + "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," + + " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" + + " ON CONFLICT DO NOTHING" + +const selectInviteActiveForUserInRoomSQL = "" + + "SELECT sender_nid FROM roomserver_invites" + + " WHERE target_nid = $1 AND room_nid = $2" + + " AND NOT retired" + +// Retire every active invite for a user in a room. +// Ideally we'd know which invite events were retired by a given update so we +// wouldn't need to remove every active invite. +// However the matrix protocol doesn't give us a way to reliably identify the +// invites that were retired, so we are forced to retire all of them. +const updateInviteRetiredSQL = ` + UPDATE roomserver_invites SET retired = TRUE + WHERE room_nid = $1 AND target_nid = $2 AND NOT retired; + SELECT invite_event_id FROM roomserver_invites + WHERE rowid = last_insert_rowid(); +` + +type inviteStatements struct { + insertInviteEventStmt *sql.Stmt + selectInviteActiveForUserInRoomStmt *sql.Stmt + updateInviteRetiredStmt *sql.Stmt +} + +func (s *inviteStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(inviteSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL}, + {&s.updateInviteRetiredStmt, updateInviteRetiredSQL}, + }.prepare(db) +} + +func (s *inviteStatements) insertInviteEvent( + ctx context.Context, + txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, + targetUserNID, senderUserNID types.EventStateKeyNID, + inviteEventJSON []byte, +) (bool, error) { + stmt := common.TxStmt(txn, s.insertInviteEventStmt) + defer stmt.Close() // nolint: errcheck + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err := result.RowsAffected() + if err != nil { + return false, err + } + return count != 0, nil +} + +func (s *inviteStatements) updateInviteRetired( + ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventIDs []string, err error) { + stmt := common.TxStmt(txn, s.updateInviteRetiredStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return nil, err + } + defer (func() { err = rows.Close() })() + for rows.Next() { + var inviteEventID string + if err := rows.Scan(&inviteEventID); err != nil { + return nil, err + } + eventIDs = append(eventIDs, inviteEventID) + } + return +} + +// selectInviteActiveForUserInRoom returns a list of sender state key NIDs +func (s *inviteStatements) selectInviteActiveForUserInRoom( + ctx context.Context, + targetUserNID types.EventStateKeyNID, roomNID types.RoomNID, +) ([]types.EventStateKeyNID, error) { + rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext( + ctx, targetUserNID, roomNID, + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + var result []types.EventStateKeyNID + for rows.Next() { + var senderUserNID int64 + if err := rows.Scan(&senderUserNID); err != nil { + return nil, err + } + result = append(result, types.EventStateKeyNID(senderUserNID)) + } + return result, nil +} diff --git a/roomserver/storage/sqlite3/list.go b/roomserver/storage/sqlite3/list.go new file mode 100644 index 00000000..4fe4e334 --- /dev/null +++ b/roomserver/storage/sqlite3/list.go @@ -0,0 +1,18 @@ +package sqlite3 + +import ( + "strconv" + "strings" + + "github.com/lib/pq" +) + +type SqliteList string + +func sqliteIn(a pq.Int64Array) string { + var b []string + for _, n := range a { + b = append(b, strconv.FormatInt(n, 10)) + } + return strings.Join(b, ",") +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go new file mode 100644 index 00000000..97877673 --- /dev/null +++ b/roomserver/storage/sqlite3/membership_table.go @@ -0,0 +1,180 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +type membershipState int64 + +const ( + membershipStateLeaveOrBan membershipState = 1 + membershipStateInvite membershipState = 2 + membershipStateJoin membershipState = 3 +) + +const membershipSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_membership ( + room_nid INTEGER NOT NULL, + target_nid INTEGER NOT NULL, + sender_nid INTEGER NOT NULL DEFAULT 0, + membership_nid INTEGER NOT NULL DEFAULT 1, + event_nid INTEGER NOT NULL DEFAULT 0, + UNIQUE (room_nid, target_nid) + ); +` + +// Insert a row in to membership table so that it can be locked by the +// SELECT FOR UPDATE +const insertMembershipSQL = "" + + "INSERT INTO roomserver_membership (room_nid, target_nid)" + + " VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectMembershipFromRoomAndTargetSQL = "" + + "SELECT membership_nid, event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const selectMembershipsFromRoomAndMembershipSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND membership_nid = $2" + +const selectMembershipsFromRoomSQL = "" + + "SELECT event_nid FROM roomserver_membership" + + " WHERE room_nid = $1" + +const selectMembershipForUpdateSQL = "" + + "SELECT membership_nid FROM roomserver_membership" + + " WHERE room_nid = $1 AND target_nid = $2" + +const updateMembershipSQL = "" + + "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" + + " WHERE room_nid = $4 AND target_nid = $5" + +type membershipStatements struct { + insertMembershipStmt *sql.Stmt + selectMembershipForUpdateStmt *sql.Stmt + selectMembershipFromRoomAndTargetStmt *sql.Stmt + selectMembershipsFromRoomAndMembershipStmt *sql.Stmt + selectMembershipsFromRoomStmt *sql.Stmt + updateMembershipStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertMembershipStmt, insertMembershipSQL}, + {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL}, + {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL}, + {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL}, + {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL}, + {&s.updateMembershipStmt, updateMembershipSQL}, + }.prepare(db) +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) error { + stmt := common.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID) + return err +} + +func (s *membershipStatements) selectMembershipForUpdate( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (membership membershipState, err error) { + stmt := common.TxStmt(txn, s.selectMembershipForUpdateStmt) + err = stmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership) + return +} + +func (s *membershipStatements) selectMembershipFromRoomAndTarget( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, +) (eventNID types.EventNID, membership membershipState, err error) { + selectStmt := common.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt) + err = selectStmt.QueryRowContext( + ctx, roomNID, targetUserNID, + ).Scan(&membership, &eventNID) + return +} + +func (s *membershipStatements) selectMembershipsFromRoom( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, +) (eventNIDs []types.EventNID, err error) { + selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt) + rows, err := selectStmt.QueryContext(ctx, roomNID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} +func (s *membershipStatements) selectMembershipsFromRoomAndMembership( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, membership membershipState, +) (eventNIDs []types.EventNID, err error) { + stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt) + rows, err := stmt.QueryContext(ctx, roomNID, membership) + if err != nil { + return + } + defer rows.Close() // nolint: errcheck + + for rows.Next() { + var eNID types.EventNID + if err = rows.Scan(&eNID); err != nil { + return + } + eventNIDs = append(eventNIDs, eNID) + } + return +} + +func (s *membershipStatements) updateMembership( + ctx context.Context, txn *sql.Tx, + roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, + senderUserNID types.EventStateKeyNID, membership membershipState, + eventNID types.EventNID, +) error { + stmt := common.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ) + return err +} diff --git a/roomserver/storage/sqlite3/prepare.go b/roomserver/storage/sqlite3/prepare.go new file mode 100644 index 00000000..482dfa2b --- /dev/null +++ b/roomserver/storage/sqlite3/prepare.go @@ -0,0 +1,36 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" +) + +// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. +type statementList []struct { + statement **sql.Stmt + sql string +} + +// prepare the SQL for each statement in the list and assign the result to the prepared statement. +func (s statementList) prepare(db *sql.DB) (err error) { + for _, statement := range s { + if *statement.statement, err = db.Prepare(statement.sql); err != nil { + return + } + } + return +} diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go new file mode 100644 index 00000000..9ed64a38 --- /dev/null +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -0,0 +1,92 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const previousEventSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_previous_events ( + previous_event_id TEXT NOT NULL, + previous_reference_sha256 BLOB NOT NULL, + event_nids TEXT NOT NULL, + UNIQUE (previous_event_id, previous_reference_sha256) + ); +` + +// Insert an entry into the previous_events table. +// If there is already an entry indicating that an event references that previous event then +// add the event NID to the list to indicate that this event references that previous event as well. +// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +// The lock is necessary to avoid data races when checking whether an event is already referenced by another event. +const insertPreviousEventSQL = ` + INSERT OR REPLACE INTO roomserver_previous_events + (previous_event_id, previous_reference_sha256, event_nids) + VALUES ($1, $2, $3) +` + +// Check if the event is referenced by another event in the table. +// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. +const selectPreviousEventExistsSQL = ` + SELECT 1 FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +` + +type previousEventStatements struct { + insertPreviousEventStmt *sql.Stmt + selectPreviousEventExistsStmt *sql.Stmt +} + +func (s *previousEventStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(previousEventSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, + }.prepare(db) +} + +func (s *previousEventStatements) insertPreviousEvent( + ctx context.Context, + txn *sql.Tx, + previousEventID string, + previousEventReferenceSHA256 []byte, + eventNID types.EventNID, +) error { + stmt := common.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err +} + +// Check if the event reference exists +// Returns sql.ErrNoRows if the event reference doesn't exist. +func (s *previousEventStatements) selectPreviousEventExists( + ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte, +) error { + var ok int64 + stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt) + return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok) +} diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go new file mode 100644 index 00000000..a5fd5449 --- /dev/null +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -0,0 +1,135 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" +) + +const roomAliasesSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_room_aliases ( + alias TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + creator_id TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id); +` + +const insertRoomAliasSQL = ` + INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3) +` + +const selectRoomIDFromAliasSQL = ` + SELECT room_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const selectAliasesFromRoomIDSQL = ` + SELECT alias FROM roomserver_room_aliases WHERE room_id = $1 +` + +const selectCreatorIDFromAliasSQL = ` + SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1 +` + +const deleteRoomAliasSQL = ` + DELETE FROM roomserver_room_aliases WHERE alias = $1 +` + +type roomAliasesStatements struct { + insertRoomAliasStmt *sql.Stmt + selectRoomIDFromAliasStmt *sql.Stmt + selectAliasesFromRoomIDStmt *sql.Stmt + selectCreatorIDFromAliasStmt *sql.Stmt + deleteRoomAliasStmt *sql.Stmt +} + +func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomAliasesSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomAliasStmt, insertRoomAliasSQL}, + {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL}, + {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL}, + {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL}, + {&s.deleteRoomAliasStmt, deleteRoomAliasSQL}, + }.prepare(db) +} + +func (s *roomAliasesStatements) insertRoomAlias( + ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, +) (err error) { + insertStmt := common.TxStmt(txn, s.insertRoomAliasStmt) + _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return +} + +func (s *roomAliasesStatements) selectRoomIDFromAlias( + ctx context.Context, txn *sql.Tx, alias string, +) (roomID string, err error) { + selectStmt := common.TxStmt(txn, s.selectRoomIDFromAliasStmt) + err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) selectAliasesFromRoomID( + ctx context.Context, txn *sql.Tx, roomID string, +) (aliases []string, err error) { + aliases = []string{} + selectStmt := common.TxStmt(txn, s.selectAliasesFromRoomIDStmt) + rows, err := selectStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var alias string + if err = rows.Scan(&alias); err != nil { + return + } + + aliases = append(aliases, alias) + } + + return +} + +func (s *roomAliasesStatements) selectCreatorIDFromAlias( + ctx context.Context, txn *sql.Tx, alias string, +) (creatorID string, err error) { + selectStmt := common.TxStmt(txn, s.selectCreatorIDFromAliasStmt) + err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *roomAliasesStatements) deleteRoomAlias( + ctx context.Context, txn *sql.Tx, alias string, +) (err error) { + deleteStmt := common.TxStmt(txn, s.deleteRoomAliasStmt) + _, err = deleteStmt.ExecContext(ctx, alias) + return +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go new file mode 100644 index 00000000..bf237728 --- /dev/null +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -0,0 +1,165 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const roomsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_rooms ( + room_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_id TEXT NOT NULL UNIQUE, + latest_event_nids TEXT NOT NULL DEFAULT '{}', + last_event_sent_nid INTEGER NOT NULL DEFAULT 0, + state_snapshot_nid INTEGER NOT NULL DEFAULT 0, + room_version INTEGER NOT NULL DEFAULT 1 + ); +` + +// Same as insertEventTypeNIDSQL +const insertRoomNIDSQL = ` + INSERT INTO roomserver_rooms (room_id) VALUES ($1) + ON CONFLICT DO NOTHING; +` + +const selectRoomNIDSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + +const selectLatestEventNIDsSQL = "" + + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const selectLatestEventNIDsForUpdateSQL = "" + + "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" + +const updateLatestEventNIDsSQL = "" + + "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4" + +const selectRoomVersionForRoomNIDSQL = "" + + "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1" + +type roomStatements struct { + insertRoomNIDStmt *sql.Stmt + selectRoomNIDStmt *sql.Stmt + selectLatestEventNIDsStmt *sql.Stmt + selectLatestEventNIDsForUpdateStmt *sql.Stmt + updateLatestEventNIDsStmt *sql.Stmt + selectRoomVersionForRoomNIDStmt *sql.Stmt +} + +func (s *roomStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(roomsSchema) + if err != nil { + return + } + return statementList{ + {&s.insertRoomNIDStmt, insertRoomNIDSQL}, + {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, + {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, + {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, + {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL}, + }.prepare(db) +} + +func (s *roomStatements) insertRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var err error + insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt) + if _, err = insertStmt.ExecContext(ctx, roomID); err == nil { + return s.selectRoomNID(ctx, txn, roomID) + } else { + return types.RoomNID(0), err + } +} + +func (s *roomStatements) selectRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := common.TxStmt(txn, s.selectRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + +func (s *roomStatements) selectLatestEventNIDs( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var stateSnapshotNID int64 + stmt := common.TxStmt(txn, s.selectLatestEventNIDsStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID) + if err != nil { + return nil, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) selectLatestEventsNIDsForUpdate( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) { + var nids pq.Int64Array + var lastEventSentNID int64 + var stateSnapshotNID int64 + stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt) + err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID) + if err != nil { + return nil, 0, 0, err + } + eventNIDs := make([]types.EventNID, len(nids)) + for i := range nids { + eventNIDs[i] = types.EventNID(nids[i]) + } + return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil +} + +func (s *roomStatements) updateLatestEventNIDs( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, + lastEventSentNID types.EventNID, + stateSnapshotNID types.StateSnapshotNID, +) error { + stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err +} + +func (s *roomStatements) selectRoomVersionForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) (int64, error) { + var roomVersion int64 + stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt) + err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion) + return roomVersion, err +} diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go new file mode 100644 index 00000000..0d49432b --- /dev/null +++ b/roomserver/storage/sqlite3/sql.go @@ -0,0 +1,60 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" +) + +type statements struct { + eventTypeStatements + eventStateKeyStatements + roomStatements + eventStatements + eventJSONStatements + stateSnapshotStatements + stateBlockStatements + previousEventStatements + roomAliasesStatements + inviteStatements + membershipStatements + transactionStatements +} + +func (s *statements) prepare(db *sql.DB) error { + var err error + + for _, prepare := range []func(db *sql.DB) error{ + s.eventTypeStatements.prepare, + s.eventStateKeyStatements.prepare, + s.roomStatements.prepare, + s.eventStatements.prepare, + s.eventJSONStatements.prepare, + s.stateSnapshotStatements.prepare, + s.stateBlockStatements.prepare, + s.previousEventStatements.prepare, + s.roomAliasesStatements.prepare, + s.inviteStatements.prepare, + s.membershipStatements.prepare, + s.transactionStatements.prepare, + } { + if err = prepare(db); err != nil { + return err + } + } + + return nil +} diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go new file mode 100644 index 00000000..ac593546 --- /dev/null +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -0,0 +1,292 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "sort" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" +) + +const stateDataSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_block ( + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_type_nid INTEGER NOT NULL, + event_state_key_nid INTEGER NOT NULL, + event_nid INTEGER NOT NULL, + UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + ); +` + +const insertStateDataSQL = "" + + "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + + " VALUES ($1, $2, $3, $4)" + +const selectNextStateBlockNIDSQL = ` + SELECT COALESCE(( + SELECT seq+1 AS state_block_nid FROM sqlite_sequence + WHERE name = 'roomserver_state_block'), 1 + ) AS state_block_nid +` + +// Bulk state lookup by numeric state block ID. +// Sort by the state_block_nid, event_type_nid, event_state_key_nid +// This means that all the entries for a given state_block_nid will appear +// together in the list and those entries will sorted by event_type_nid +// and event_state_key_nid. This property makes it easier to merge two +// state data blocks together. +const bulkSelectStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +// Bulk state lookup by numeric state block ID. +// Filters the rows in each block to the requested types and state keys. +// We would like to restrict to particular type state key pairs but we are +// restricted by the query language to pull the cross product of a list +// of types and a list state_keys. So we have to filter the result in the +// application to restrict it to the list of event types and state keys we +// actually wanted. +const bulkSelectFilteredStateBlockEntriesSQL = "" + + "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + + " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + + " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + +type stateBlockStatements struct { + db *sql.DB + insertStateDataStmt *sql.Stmt + selectNextStateBlockNIDStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt + bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt +} + +func (s *stateBlockStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(stateDataSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateDataStmt, insertStateDataSQL}, + {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, + {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, + {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, + }.prepare(db) +} + +func (s *stateBlockStatements) bulkInsertStateData( + ctx context.Context, txn *sql.Tx, + stateBlockNID types.StateBlockNID, + entries []types.StateEntry, +) error { + for _, entry := range entries { + _, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) + if err != nil { + return err + } + } + return nil +} + +func (s *stateBlockStatements) selectNextStateBlockNID( + ctx context.Context, + txn *sql.Tx, +) (types.StateBlockNID, error) { + var stateBlockNID int64 + selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt) + err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID) + return types.StateBlockNID(stateBlockNID), err +} + +func (s *stateBlockStatements) bulkSelectStateBlockEntries( + ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + nids := make([]interface{}, len(stateBlockNIDs)) + for k, v := range stateBlockNIDs { + nids[k] = v + } + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1) + selectPrep, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + selectStmt := common.TxStmt(txn, selectPrep) + rows, err := selectStmt.QueryContext(ctx, nids...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + results := make([]types.StateEntryList, len(stateBlockNIDs)) + // current is a pointer to the StateEntryList to append the state entries to. + var current *types.StateEntryList + i := 0 + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we start appending to the next entry in the list. + current = &results[i] + current.StateBlockNID = types.StateBlockNID(stateBlockNID) + i++ + } + current.StateEntries = append(current.StateEntries, entry) + } + if i != len(nids) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) + } + return results, nil +} + +func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries( + ctx context.Context, txn *sql.Tx, // nolint: unparam + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + tuples := stateKeyTupleSorter(stateKeyTuples) + // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. + sort.Sort(tuples) + + eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() + sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) + sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) + + var params []interface{} + for _, val := range stateBlockNIDs { + params = append(params, int64(val)) + } + for _, val := range eventTypeNIDArray { + params = append(params, val) + } + for _, val := range eventStateKeyNIDArray { + params = append(params, val) + } + + rows, err := s.db.QueryContext( + ctx, + sqlStatement, + params..., + ) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + + var results []types.StateEntryList + var current types.StateEntryList + for rows.Next() { + var ( + stateBlockNID int64 + eventTypeNID int64 + eventStateKeyNID int64 + eventNID int64 + entry types.StateEntry + ) + if err := rows.Scan( + &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, + ); err != nil { + return nil, err + } + entry.EventTypeNID = types.EventTypeNID(eventTypeNID) + entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) + entry.EventNID = types.EventNID(eventNID) + + // We can use binary search here because we sorted the tuples earlier + if !tuples.contains(entry.StateKeyTuple) { + // The select will return the cross product of types and state keys. + // So we need to check if type of the entry is in the list. + continue + } + + if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { + // The state entry row is for a different state data block to the current one. + // So we append the current entry to the results and start adding to a new one. + // The first time through the loop current will be empty. + if current.StateEntries != nil { + results = append(results, current) + } + current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} + } + current.StateEntries = append(current.StateEntries, entry) + } + // Add the last entry to the list if it is not empty. + if current.StateEntries != nil { + results = append(results, current) + } + return results, nil +} + +type stateKeyTupleSorter []types.StateKeyTuple + +func (s stateKeyTupleSorter) Len() int { return len(s) } +func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } +func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Check whether a tuple is in the list. Assumes that the list is sorted. +func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { + i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) + return i < len(s) && s[i] == value +} + +// List the unique eventTypeNIDs and eventStateKeyNIDs. +// Assumes that the list is sorted. +func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { + eventTypeNIDs = make(pq.Int64Array, len(s)) + eventStateKeyNIDs = make(pq.Int64Array, len(s)) + for i := range s { + eventTypeNIDs[i] = int64(s[i].EventTypeNID) + eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) + } + eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] + eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] + return +} + +type int64Sorter []int64 + +func (s int64Sorter) Len() int { return len(s) } +func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } +func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go new file mode 100644 index 00000000..98439f5c --- /dev/null +++ b/roomserver/storage/sqlite3/state_block_table_test.go @@ -0,0 +1,86 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "sort" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" +) + +func TestStateKeyTupleSorter(t *testing.T) { + input := stateKeyTupleSorter{ + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 1}, + } + want := []types.StateKeyTuple{ + {EventTypeNID: 1, EventStateKeyNID: 1}, + {EventTypeNID: 1, EventStateKeyNID: 2}, + {EventTypeNID: 1, EventStateKeyNID: 4}, + {EventTypeNID: 2, EventStateKeyNID: 2}, + } + doNotWant := []types.StateKeyTuple{ + {EventTypeNID: 0, EventStateKeyNID: 0}, + {EventTypeNID: 1, EventStateKeyNID: 3}, + {EventTypeNID: 2, EventStateKeyNID: 1}, + {EventTypeNID: 3, EventStateKeyNID: 1}, + } + wantTypeNIDs := []int64{1, 2} + wantStateKeyNIDs := []int64{1, 2, 4} + + // Sort the input and check it's in the right order. + sort.Sort(input) + gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() + + for i := range want { + if input[i] != want[i] { + t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) + } + + if !input.contains(want[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) + } + } + + for i := range doNotWant { + if input.contains(doNotWant[i]) { + t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) + } + } + + if len(wantTypeNIDs) != len(gotTypeNIDs) { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + + for i := range wantTypeNIDs { + if wantTypeNIDs[i] != gotTypeNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } + + if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { + t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) + } + + for i := range wantStateKeyNIDs { + if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { + t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) + } + } +} diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go new file mode 100644 index 00000000..df97aa41 --- /dev/null +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -0,0 +1,120 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSnapshotSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( + state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + state_block_nids TEXT NOT NULL DEFAULT '{}' + ); +` + +const insertStateSQL = ` + INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) + VALUES ($1, $2);` + +// Bulk state data NID lookup. +// Sorting by state_snapshot_nid means we can use binary search over the result +// to lookup the state data NIDs for a state snapshot NID. +const bulkSelectStateBlockNIDsSQL = "" + + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" + +type stateSnapshotStatements struct { + db *sql.DB + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt +} + +func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) { + s.db = db + _, err = db.Exec(stateSnapshotSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertStateStmt, insertStateSQL}, + {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + }.prepare(db) +} + +func (s *stateSnapshotStatements) insertState( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, +) (stateNID types.StateSnapshotNID, err error) { + nids := make([]int64, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + nids[i] = int64(stateBlockNIDs[i]) + } + insertStmt := txn.Stmt(s.insertStateStmt) + if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err2 == nil { + lastRowID, err3 := res.LastInsertId() + if err3 != nil { + err = err3 + } + stateNID = types.StateSnapshotNID(lastRowID) + } + return +} + +func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs( + ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + nids := make([]interface{}, len(stateNIDs)) + for k, v := range stateNIDs { + nids[k] = v + } + selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1) + selectStmt, err := txn.Prepare(selectOrig) + if err != nil { + return nil, err + } + + rows, err := selectStmt.QueryContext(ctx, nids...) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + results := make([]types.StateBlockNIDList, len(stateNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + var stateBlockNIDs pq.Int64Array + if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { + return nil, err + } + result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) + for k := range stateBlockNIDs { + result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) + } + } + if i != len(stateNIDs) { + return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) + } + return results, nil +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go new file mode 100644 index 00000000..e20e8aed --- /dev/null +++ b/roomserver/storage/sqlite3/storage.go @@ -0,0 +1,864 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "net/url" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + _ "github.com/mattn/go-sqlite3" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + statements statements + db *sql.DB +} + +// Open a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + var cs string + if uri.Opaque != "" { // file:filename.db + cs = uri.Opaque + } else if uri.Path != "" { // file:///path/to/filename.db + cs = uri.Path + } else { + return nil, errors.New("no filename or path in connect string") + } + if d.db, err = sql.Open("sqlite3", cs); err != nil { + return nil, err + } + //d.db.Exec("PRAGMA journal_mode=WAL;") + //d.db.Exec("PRAGMA read_uncommitted = true;") + d.db.SetMaxOpenConns(2) + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.statements.insertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } + } + + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } + } + + if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (roomNID types.RoomNID, err error) { + // Check if we already have a numeric ID in the database. + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + if err == nil { + // Now get the numeric ID back out of the database + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + } + } + return +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + } + } + return +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (eventStateKeyNID types.EventStateKeyNID, err error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return +} + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) (se []types.StateEntry, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) + return err + }) + return +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (etnids map[string]types.EventTypeNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) + return err + }) + return +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (esknids map[string]types.EventStateKeyNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + return err + }) + return +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (out map[types.EventStateKeyNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) + return err + }) + return +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (out map[string]types.EventNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) + return err + }) + return +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + var eventJSONs []eventJSONPair + var err error + results := make([]types.Event, len(eventNIDs)) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) + if err != nil || len(eventJSONs) == 0 { + return nil + } + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil + } + } + return nil + }) + if err != nil { + return []types.Event{}, err + } + return results, nil +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if len(state) > 0 { + var stateBlockNID types.StateBlockNID + stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn) + if err != nil { + return err + } + if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) + return err + }) + if err != nil { + return 0, err + } + return +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + e := common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.statements.updateEventState(ctx, txn, eventNID, stateNID) + }) + return e +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) (se []types.StateAtEvent, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) + return err + }) + return +} + +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) (sl []types.StateBlockNIDList, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) + return err + }) + return +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) (sel []types.StateEntryList, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) + return err + }) + return +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (stateNID types.StateSnapshotNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) + return err + }) + return +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (out map[types.EventNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) + return err + }) + return +} + +// GetLatestEventsForUpdate implements input.EventDatabase +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + + // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get + // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here) + // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so + // we fail fast if someone tries to use the underlying txn object. + err = txn.Commit() + if err != nil { + return nil, err + } + return &roomRecentEventsUpdater{ + transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil + }) + return err +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + res = true + err = nil + } + if err == sql.ErrNoRows { + res = false + err = nil + } + return err + }) + return +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + }) + return err +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) + return err + }) + return +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) + }) + return err +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + return err + }) + return +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + roomNID = 0 + err = nil + } + return err + }) + return +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return +} + +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, nil, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) +} + +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, nil, alias) +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, nil, stateBlockNIDs, stateKeyTuples, + ) +} + +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == membershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender()) + if err != nil { + return err + } + inserted, err = u.d.statements.insertInviteEvent( + u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return err + } + } + return nil + }) + return +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) + if err != nil { + return err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + + return +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) + if err != nil { + return err + } + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + return +} + +// GetMembership implements query.RoomserverQueryAPIDB +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) + if err != nil { + return err + } + + membershipEventNID, _, err = + d.statements.selectMembershipFromRoomAndTarget( + ctx, txn, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return nil + } + if err != nil { + return err + } + stillInRoom = true + return nil + }) + + return +} + +// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) (eventNIDs []types.EventNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if joinOnly { + eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( + ctx, txn, roomNID, membershipStateJoin, + ) + return nil + } + + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + return nil + }) + return +} + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomNID types.RoomNID, +) (int64, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + if t.txn == nil { + return nil + } + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + if t.txn == nil { + return nil + } + return t.txn.Rollback() +} diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go new file mode 100644 index 00000000..7740e5f0 --- /dev/null +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -0,0 +1,86 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" +) + +const transactionsSchema = ` + CREATE TABLE IF NOT EXISTS roomserver_transactions ( + transaction_id TEXT NOT NULL, + session_id INTEGER NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + PRIMARY KEY (transaction_id, session_id, user_id) + ); +` +const insertTransactionSQL = ` + INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id) + VALUES ($1, $2, $3, $4) +` + +const selectTransactionEventIDSQL = ` + SELECT event_id FROM roomserver_transactions + WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3 +` + +type transactionStatements struct { + insertTransactionStmt *sql.Stmt + selectTransactionEventIDStmt *sql.Stmt +} + +func (s *transactionStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(transactionsSchema) + if err != nil { + return + } + + return statementList{ + {&s.insertTransactionStmt, insertTransactionSQL}, + {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL}, + }.prepare(db) +} + +func (s *transactionStatements) insertTransaction( + ctx context.Context, txn *sql.Tx, + transactionID string, + sessionID int64, + userID string, + eventID string, +) (err error) { + stmt := common.TxStmt(txn, s.insertTransactionStmt) + _, err = stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return +} + +func (s *transactionStatements) selectTransactionEventID( + ctx context.Context, txn *sql.Tx, + transactionID string, + sessionID int64, + userID string, +) (eventID string, err error) { + stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt) + err = stmt.QueryRowContext( + ctx, transactionID, sessionID, userID, + ).Scan(&eventID) + return +} diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 90841168..551d97cd 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -19,25 +19,20 @@ import ( "net/url" "github.com/matrix-org/dendrite/roomserver/api" + statedb "github.com/matrix-org/dendrite/roomserver/state/database" "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { + statedb.RoomStateDatabase StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) - EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) - EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) - Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) - AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error - StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) - StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) - StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) @@ -49,7 +44,6 @@ type Database interface { GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) RemoveRoomAlias(ctx context.Context, alias string) error - StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) MembershipUpdater(ctx context.Context, roomID, targetUserID string) (types.MembershipUpdater, error) GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) @@ -66,6 +60,8 @@ func Open(dataSourceName string) (Database, error) { switch uri.Scheme { case "postgres": return postgres.Open(dataSourceName) + case "file": + return sqlite3.Open(dataSourceName) default: return postgres.Open(dataSourceName) } |