diff options
author | Kegsay <kegan@matrix.org> | 2020-02-13 17:27:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-13 17:27:33 +0000 |
commit | b6ea1bc67ab51667b9e139dd05e0778aca025501 (patch) | |
tree | 18569c317fd28544144c320ce844d93a8ff8ec5e /roomserver/storage/sqlite3/storage.go | |
parent | 6942ee1de0250235164cf0ce45570b7fc919669d (diff) |
Support sqlite in addition to postgres (#869)
* Move current work into single branch
* Initial massaging of clientapi etc (not working yet)
* Interfaces for accounts/devices databases
* Duplicate postgres package for sqlite3 (no changes made to it yet)
* Some keydb, accountdb, devicedb, common partition fixes, some more syncapi tweaking
* Fix accounts DB, device DB
* Update naffka dependency for SQLite
* Naffka SQLite
* Update naffka to latest master
* SQLite support for federationsender
* Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported)
* Update Dockerfile -> Go 1.13.7, add build-base (as gcc and friends are needed for SQLite)
* Implement GET endpoints for account_data in clientapi
* Nuke filtering for now...
* Revert "Implement GET endpoints for account_data in clientapi"
This reverts commit 4d80dff4583d278620d9b3ed437e9fcd8d4674ee.
* Implement GET endpoints for account_data in clientapi (#861)
* Implement GET endpoints for account_data in clientapi
* Fix accountDB parameter
* Remove fmt.Println
* Fix insertAccountData SQLite query
* Fix accountDB storage interfaces
* Add empty push rules into account data on account creation (#862)
* Put SaveAccountData into the right function this time
* Not sure if roomserver is better or worse now
* sqlite work
* Allow empty last sent ID for the first event
* sqlite: room creation works
* Support sending messages
* Nuke fmt.println
* Move QueryVariadic etc into common, other device fixes
* Fix some linter issues
* Fix bugs
* Fix some linting errors
* Fix errcheck lint errors
* Make naffka use postgres as fallback, fix couple of compile errors
* What on earth happened to the /rooms/{roomID}/send/{eventType} routing
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'roomserver/storage/sqlite3/storage.go')
-rw-r--r-- | roomserver/storage/sqlite3/storage.go | 864 |
1 files changed, 864 insertions, 0 deletions
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go new file mode 100644 index 00000000..e20e8aed --- /dev/null +++ b/roomserver/storage/sqlite3/storage.go @@ -0,0 +1,864 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "net/url" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" + _ "github.com/mattn/go-sqlite3" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + statements statements + db *sql.DB +} + +// Open a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + var cs string + if uri.Opaque != "" { // file:filename.db + cs = uri.Opaque + } else if uri.Path != "" { // file:///path/to/filename.db + cs = uri.Path + } else { + return nil, errors.New("no filename or path in connect string") + } + if d.db, err = sql.Open("sqlite3", cs); err != nil { + return nil, err + } + //d.db.Exec("PRAGMA journal_mode=WAL;") + //d.db.Exec("PRAGMA read_uncommitted = true;") + d.db.SetMaxOpenConns(2) + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if txnAndSessionID != nil { + if err = d.statements.insertTransaction( + ctx, txn, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return err + } + } + + if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil { + return err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil { + return err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil { + return err + } + } + + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + txn, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID()) + } + if err != nil { + return err + } + } + + if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { + return err + } + + return nil + }) + if err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (roomNID types.RoomNID, err error) { + // Check if we already have a numeric ID in the database. + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + if err == nil { + // Now get the numeric ID back out of the database + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + } + } + return +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, txn *sql.Tx, eventType string, +) (eventTypeNID types.EventTypeNID, err error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType) + } + } + return +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (eventStateKeyNID types.EventStateKeyNID, err error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return +} + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) (se []types.StateEntry, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs) + return err + }) + return +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (etnids map[string]types.EventTypeNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes) + return err + }) + return +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (esknids map[string]types.EventStateKeyNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) + return err + }) + return +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (out map[types.EventStateKeyNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs) + return err + }) + return +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (out map[string]types.EventNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs) + return err + }) + return +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + var eventJSONs []eventJSONPair + var err error + results := make([]types.Event, len(eventNIDs)) + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs) + if err != nil || len(eventJSONs) == 0 { + return nil + } + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil + } + } + return nil + }) + if err != nil { + return []types.Event{}, err + } + return results, nil +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (stateNID types.StateSnapshotNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if len(state) > 0 { + var stateBlockNID types.StateBlockNID + stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn) + if err != nil { + return err + } + if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil { + return err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs) + return err + }) + if err != nil { + return 0, err + } + return +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + e := common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.statements.updateEventState(ctx, txn, eventNID, stateNID) + }) + return e +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) (se []types.StateAtEvent, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs) + return err + }) + return +} + +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) (sl []types.StateBlockNIDList, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs) + return err + }) + return +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) (sel []types.StateEntryList, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs) + return err + }) + return +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (stateNID types.StateSnapshotNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID) + return err + }) + return +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (out map[types.EventNID]string, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs) + return err + }) + return +} + +// GetLatestEventsForUpdate implements input.EventDatabase +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + + // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get + // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here) + // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so + // we fail fast if someone tries to use the underlying txn object. + err = txn.Commit() + if err != nil { + return nil, err + } + return &roomRecentEventsUpdater{ + transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil + }) + return err +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + res = true + err = nil + } + if err == sql.ErrNoRows { + res = false + err = nil + } + return err + }) + return +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + }) + return err +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID) + return err + }) + return +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID) + }) + return err +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID) + return err + }) + return +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + roomNID = 0 + err = nil + } + return err + }) + return +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID) + if err != nil { + return err + } + references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return +} + +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, nil, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, nil, roomID) +} + +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.statements.selectCreatorIDFromAlias(ctx, nil, alias) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, nil, alias) +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, nil, stateBlockNIDs, stateKeyTuples, + ) +} + +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == membershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender()) + if err != nil { + return err + } + inserted, err = u.d.statements.insertInviteEvent( + u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return err + } + } + return nil + }) + return +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) + if err != nil { + return err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + + return +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) { + err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID) + if err != nil { + return err + } + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return err + } + } + return nil + }) + return +} + +// GetMembership implements query.RoomserverQueryAPIDB +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID) + if err != nil { + return err + } + + membershipEventNID, _, err = + d.statements.selectMembershipFromRoomAndTarget( + ctx, txn, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return nil + } + if err != nil { + return err + } + stillInRoom = true + return nil + }) + + return +} + +// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) (eventNIDs []types.EventNID, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + if joinOnly { + eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership( + ctx, txn, roomNID, membershipStateJoin, + ) + return nil + } + + eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID) + return nil + }) + return +} + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +func (d *Database) GetRoomVersionForRoom( + ctx context.Context, roomNID types.RoomNID, +) (int64, error) { + return d.statements.selectRoomVersionForRoomNID( + ctx, nil, roomNID, + ) +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + if t.txn == nil { + return nil + } + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + if t.txn == nil { + return nil + } + return t.txn.Rollback() +} |