aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/sqlite3/storage.go
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-02-13 17:27:33 +0000
committerGitHub <noreply@github.com>2020-02-13 17:27:33 +0000
commitb6ea1bc67ab51667b9e139dd05e0778aca025501 (patch)
tree18569c317fd28544144c320ce844d93a8ff8ec5e /roomserver/storage/sqlite3/storage.go
parent6942ee1de0250235164cf0ce45570b7fc919669d (diff)
Support sqlite in addition to postgres (#869)
* Move current work into single branch * Initial massaging of clientapi etc (not working yet) * Interfaces for accounts/devices databases * Duplicate postgres package for sqlite3 (no changes made to it yet) * Some keydb, accountdb, devicedb, common partition fixes, some more syncapi tweaking * Fix accounts DB, device DB * Update naffka dependency for SQLite * Naffka SQLite * Update naffka to latest master * SQLite support for federationsender * Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported) * Update Dockerfile -> Go 1.13.7, add build-base (as gcc and friends are needed for SQLite) * Implement GET endpoints for account_data in clientapi * Nuke filtering for now... * Revert "Implement GET endpoints for account_data in clientapi" This reverts commit 4d80dff4583d278620d9b3ed437e9fcd8d4674ee. * Implement GET endpoints for account_data in clientapi (#861) * Implement GET endpoints for account_data in clientapi * Fix accountDB parameter * Remove fmt.Println * Fix insertAccountData SQLite query * Fix accountDB storage interfaces * Add empty push rules into account data on account creation (#862) * Put SaveAccountData into the right function this time * Not sure if roomserver is better or worse now * sqlite work * Allow empty last sent ID for the first event * sqlite: room creation works * Support sending messages * Nuke fmt.println * Move QueryVariadic etc into common, other device fixes * Fix some linter issues * Fix bugs * Fix some linting errors * Fix errcheck lint errors * Make naffka use postgres as fallback, fix couple of compile errors * What on earth happened to the /rooms/{roomID}/send/{eventType} routing Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'roomserver/storage/sqlite3/storage.go')
-rw-r--r--roomserver/storage/sqlite3/storage.go864
1 files changed, 864 insertions, 0 deletions
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
new file mode 100644
index 00000000..e20e8aed
--- /dev/null
+++ b/roomserver/storage/sqlite3/storage.go
@@ -0,0 +1,864 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "net/url"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// A Database is used to store room events and stream offsets.
+type Database struct {
+ statements statements
+ db *sql.DB
+}
+
+// Open a postgres database.
+func Open(dataSourceName string) (*Database, error) {
+ var d Database
+ uri, err := url.Parse(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ var cs string
+ if uri.Opaque != "" { // file:filename.db
+ cs = uri.Opaque
+ } else if uri.Path != "" { // file:///path/to/filename.db
+ cs = uri.Path
+ } else {
+ return nil, errors.New("no filename or path in connect string")
+ }
+ if d.db, err = sql.Open("sqlite3", cs); err != nil {
+ return nil, err
+ }
+ //d.db.Exec("PRAGMA journal_mode=WAL;")
+ //d.db.Exec("PRAGMA read_uncommitted = true;")
+ d.db.SetMaxOpenConns(2)
+ if err = d.statements.prepare(d.db); err != nil {
+ return nil, err
+ }
+ return &d, nil
+}
+
+// StoreEvent implements input.EventDatabase
+func (d *Database) StoreEvent(
+ ctx context.Context, event gomatrixserverlib.Event,
+ txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
+) (types.RoomNID, types.StateAtEvent, error) {
+ var (
+ roomNID types.RoomNID
+ eventTypeNID types.EventTypeNID
+ eventStateKeyNID types.EventStateKeyNID
+ eventNID types.EventNID
+ stateNID types.StateSnapshotNID
+ err error
+ )
+
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if txnAndSessionID != nil {
+ if err = d.statements.insertTransaction(
+ ctx, txn, txnAndSessionID.TransactionID,
+ txnAndSessionID.SessionID, event.Sender(), event.EventID(),
+ ); err != nil {
+ return err
+ }
+ }
+
+ if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil {
+ return err
+ }
+
+ if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
+ return err
+ }
+
+ eventStateKey := event.StateKey()
+ // Assigned a numeric ID for the state_key if there is one present.
+ // Otherwise set the numeric ID for the state_key to 0.
+ if eventStateKey != nil {
+ if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
+ return err
+ }
+ }
+
+ if eventNID, stateNID, err = d.statements.insertEvent(
+ ctx,
+ txn,
+ roomNID,
+ eventTypeNID,
+ eventStateKeyNID,
+ event.EventID(),
+ event.EventReference().EventSHA256,
+ authEventNIDs,
+ event.Depth(),
+ ); err != nil {
+ if err == sql.ErrNoRows {
+ // We've already inserted the event so select the numeric event ID
+ eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID())
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+
+ return roomNID, types.StateAtEvent{
+ BeforeStateSnapshotNID: stateNID,
+ StateEntry: types.StateEntry{
+ StateKeyTuple: types.StateKeyTuple{
+ EventTypeNID: eventTypeNID,
+ EventStateKeyNID: eventStateKeyNID,
+ },
+ EventNID: eventNID,
+ },
+ }, nil
+}
+
+func (d *Database) assignRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (roomNID types.RoomNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
+ if err == nil {
+ // Now get the numeric ID back out of the database
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ }
+ }
+ return
+}
+
+func (d *Database) assignEventTypeNID(
+ ctx context.Context, txn *sql.Tx, eventType string,
+) (eventTypeNID types.EventTypeNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
+ }
+ }
+ return
+}
+
+func (d *Database) assignStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (eventStateKeyNID types.EventStateKeyNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ }
+ }
+ return
+}
+
+// StateEntriesForEventIDs implements input.EventDatabase
+func (d *Database) StateEntriesForEventIDs(
+ ctx context.Context, eventIDs []string,
+) (se []types.StateEntry, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// EventTypeNIDs implements state.RoomStateDatabase
+func (d *Database) EventTypeNIDs(
+ ctx context.Context, eventTypes []string,
+) (etnids map[string]types.EventTypeNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes)
+ return err
+ })
+ return
+}
+
+// EventStateKeyNIDs implements state.RoomStateDatabase
+func (d *Database) EventStateKeyNIDs(
+ ctx context.Context, eventStateKeys []string,
+) (esknids map[string]types.EventStateKeyNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
+ return err
+ })
+ return
+}
+
+// EventStateKeys implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventStateKeys(
+ ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+) (out map[types.EventStateKeyNID]string, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs)
+ return err
+ })
+ return
+}
+
+// EventNIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventNIDs(
+ ctx context.Context, eventIDs []string,
+) (out map[string]types.EventNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// Events implements input.EventDatabase
+func (d *Database) Events(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]types.Event, error) {
+ var eventJSONs []eventJSONPair
+ var err error
+ results := make([]types.Event, len(eventNIDs))
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
+ if err != nil || len(eventJSONs) == 0 {
+ return nil
+ }
+ for i, eventJSON := range eventJSONs {
+ result := &results[i]
+ result.EventNID = eventJSON.EventNID
+ // TODO: Use NewEventFromTrustedJSON for efficiency
+ result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ return []types.Event{}, err
+ }
+ return results, nil
+}
+
+// AddState implements input.EventDatabase
+func (d *Database) AddState(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+) (stateNID types.StateSnapshotNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if len(state) > 0 {
+ var stateBlockNID types.StateBlockNID
+ stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn)
+ if err != nil {
+ return err
+ }
+ if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil {
+ return err
+ }
+ stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
+ }
+ stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
+ return err
+ })
+ if err != nil {
+ return 0, err
+ }
+ return
+}
+
+// SetState implements input.EventDatabase
+func (d *Database) SetState(
+ ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ e := common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.statements.updateEventState(ctx, txn, eventNID, stateNID)
+ })
+ return e
+}
+
+// StateAtEventIDs implements input.EventDatabase
+func (d *Database) StateAtEventIDs(
+ ctx context.Context, eventIDs []string,
+) (se []types.StateAtEvent, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// StateBlockNIDs implements state.RoomStateDatabase
+func (d *Database) StateBlockNIDs(
+ ctx context.Context, stateNIDs []types.StateSnapshotNID,
+) (sl []types.StateBlockNIDList, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
+ return err
+ })
+ return
+}
+
+// StateEntries implements state.RoomStateDatabase
+func (d *Database) StateEntries(
+ ctx context.Context, stateBlockNIDs []types.StateBlockNID,
+) (sel []types.StateEntryList, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
+ return err
+ })
+ return
+}
+
+// SnapshotNIDFromEventID implements state.RoomStateDatabase
+func (d *Database) SnapshotNIDFromEventID(
+ ctx context.Context, eventID string,
+) (stateNID types.StateSnapshotNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID)
+ return err
+ })
+ return
+}
+
+// EventIDs implements input.RoomEventDatabase
+func (d *Database) EventIDs(
+ ctx context.Context, eventNIDs []types.EventNID,
+) (out map[types.EventNID]string, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs)
+ return err
+ })
+ return
+}
+
+// GetLatestEventsForUpdate implements input.EventDatabase
+func (d *Database) GetLatestEventsForUpdate(
+ ctx context.Context, roomNID types.RoomNID,
+) (types.RoomRecentEventsUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
+ d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ var lastEventIDSent string
+ if lastEventNIDSent != 0 {
+ lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ }
+
+ // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get
+ // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here)
+ // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so
+ // we fail fast if someone tries to use the underlying txn object.
+ err = txn.Commit()
+ if err != nil {
+ return nil, err
+ }
+ return &roomRecentEventsUpdater{
+ transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
+ }, nil
+}
+
+// GetTransactionEventID implements input.EventDatabase
+func (d *Database) GetTransactionEventID(
+ ctx context.Context, transactionID string,
+ sessionID int64, userID string,
+) (string, error) {
+ eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return eventID, err
+}
+
+type roomRecentEventsUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ latestEvents []types.StateAtEventAndReference
+ lastEventIDSent string
+ currentStateSnapshotNID types.StateSnapshotNID
+}
+
+// LatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
+ return u.latestEvents
+}
+
+// LastEventIDSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LastEventIDSent() string {
+ return u.lastEventIDSent
+}
+
+// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
+ return u.currentStateSnapshotNID
+}
+
+// StorePreviousEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ for _, ref := range previousEventReferences {
+ if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return err
+}
+
+// IsReferenced implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
+ if err == nil {
+ res = true
+ err = nil
+ }
+ if err == sql.ErrNoRows {
+ res = false
+ err = nil
+ }
+ return err
+ })
+ return
+}
+
+// SetLatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) SetLatestEvents(
+ roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
+ currentStateSnapshotNID types.StateSnapshotNID,
+) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ eventNIDs := make([]types.EventNID, len(latest))
+ for i := range latest {
+ eventNIDs[i] = latest[i].EventNID
+ }
+ return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
+ })
+ return err
+}
+
+// HasEventBeenSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID)
+ return err
+ })
+ return
+}
+
+// MarkEventAsSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID)
+ })
+ return err
+}
+
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
+ return err
+ })
+ return
+}
+
+// RoomNID implements query.RoomserverQueryAPIDB
+func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ roomNID = 0
+ err = nil
+ }
+ return err
+ })
+ return
+}
+
+// LatestEventIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) LatestEventIDs(
+ ctx context.Context, roomNID types.RoomNID,
+) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var eventNIDs []types.EventNID
+ eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID)
+ if err != nil {
+ return err
+ }
+ references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
+func (d *Database) GetInvitesForUser(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (senderUserIDs []types.EventStateKeyNID, err error) {
+ return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
+}
+
+// SetRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
+ return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID)
+}
+
+// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
+ return d.statements.selectRoomIDFromAlias(ctx, nil, alias)
+}
+
+// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
+func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
+ return d.statements.selectAliasesFromRoomID(ctx, nil, roomID)
+}
+
+// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) GetCreatorIDForAlias(
+ ctx context.Context, alias string,
+) (string, error) {
+ return d.statements.selectCreatorIDFromAlias(ctx, nil, alias)
+}
+
+// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
+ return d.statements.deleteRoomAlias(ctx, nil, alias)
+}
+
+// StateEntriesForTuples implements state.RoomStateDatabase
+func (d *Database) StateEntriesForTuples(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ return d.statements.bulkSelectFilteredStateBlockEntries(
+ ctx, nil, stateBlockNIDs, stateKeyTuples,
+ )
+}
+
+// MembershipUpdater implements input.RoomEventDatabase
+func (d *Database) MembershipUpdater(
+ ctx context.Context, roomID, targetUserID string,
+) (types.MembershipUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ succeeded := false
+ defer func() {
+ if !succeeded {
+ txn.Rollback() // nolint: errcheck
+ }
+ }()
+
+ roomNID, err := d.assignRoomNID(ctx, txn, roomID)
+ if err != nil {
+ return nil, err
+ }
+
+ targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
+ if err != nil {
+ return nil, err
+ }
+
+ updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ succeeded = true
+ return updater, nil
+}
+
+type membershipUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ targetUserNID types.EventStateKeyNID
+ membership membershipState
+}
+
+func (d *Database) membershipUpdaterTxn(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (types.MembershipUpdater, error) {
+
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ return nil, err
+ }
+
+ membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ return &membershipUpdater{
+ transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
+ }, nil
+}
+
+// IsInvite implements types.MembershipUpdater
+func (u *membershipUpdater) IsInvite() bool {
+ return u.membership == membershipStateInvite
+}
+
+// IsJoin implements types.MembershipUpdater
+func (u *membershipUpdater) IsJoin() bool {
+ return u.membership == membershipStateJoin
+}
+
+// IsLeave implements types.MembershipUpdater
+func (u *membershipUpdater) IsLeave() bool {
+ return u.membership == membershipStateLeaveOrBan
+}
+
+// SetToInvite implements types.MembershipUpdater
+func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender())
+ if err != nil {
+ return err
+ }
+ inserted, err = u.d.statements.insertInviteEvent(
+ u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
+ )
+ if err != nil {
+ return err
+ }
+ if u.membership != membershipStateInvite {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return
+}
+
+// SetToJoin implements types.MembershipUpdater
+func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
+ if err != nil {
+ return err
+ }
+
+ // If this is a join event update, there is no invite to update
+ if !isUpdate {
+ inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ u.ctx, txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ // Look up the NID of the new join event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return err
+ }
+
+ if u.membership != membershipStateJoin || isUpdate {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateJoin, nIDs[eventID],
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+
+ return
+}
+
+// SetToLeave implements types.MembershipUpdater
+func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
+ if err != nil {
+ return err
+ }
+ inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ u.ctx, txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Look up the NID of the new leave event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return err
+ }
+
+ if u.membership != membershipStateLeaveOrBan {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateLeaveOrBan, nIDs[eventID],
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return
+}
+
+// GetMembership implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembership(
+ ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
+) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID)
+ if err != nil {
+ return err
+ }
+
+ membershipEventNID, _, err =
+ d.statements.selectMembershipFromRoomAndTarget(
+ ctx, txn, roomNID, requestSenderUserNID,
+ )
+ if err == sql.ErrNoRows {
+ // The user has never been a member of that room
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ stillInRoom = true
+ return nil
+ })
+
+ return
+}
+
+// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembershipEventNIDsForRoom(
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+) (eventNIDs []types.EventNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if joinOnly {
+ eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
+ ctx, txn, roomNID, membershipStateJoin,
+ )
+ return nil
+ }
+
+ eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
+ return nil
+ })
+ return
+}
+
+// EventsFromIDs implements query.RoomserverQueryAPIEventDB
+func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
+ nidMap, err := d.EventNIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ var nids []types.EventNID
+ for _, nid := range nidMap {
+ nids = append(nids, nid)
+ }
+
+ return d.Events(ctx, nids)
+}
+
+func (d *Database) GetRoomVersionForRoom(
+ ctx context.Context, roomNID types.RoomNID,
+) (int64, error) {
+ return d.statements.selectRoomVersionForRoomNID(
+ ctx, nil, roomNID,
+ )
+}
+
+type transaction struct {
+ ctx context.Context
+ txn *sql.Tx
+}
+
+// Commit implements types.Transaction
+func (t *transaction) Commit() error {
+ if t.txn == nil {
+ return nil
+ }
+ return t.txn.Commit()
+}
+
+// Rollback implements types.Transaction
+func (t *transaction) Rollback() error {
+ if t.txn == nil {
+ return nil
+ }
+ return t.txn.Rollback()
+}