diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-08-19 15:38:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-19 15:38:27 +0100 |
commit | b24747b305a0770fdd746655e702aa1c1c049765 (patch) | |
tree | 88d94b762fafb4852421eb243313edbfc96ccfa9 /roomserver | |
parent | 775b04d776ddc06fdee5ece6a407008f00edb7f2 (diff) |
Transaction writer changes, move roomserver writers (#1285)
* Updated TransactionWriters, moved locks in roomserver, various other tweaks
* Fix redaction deadlocks
* Fix lint issue
* Rename SQLiteTransactionWriter to ExclusiveTransactionWriter
* Fix us not sending transactions through in latest events updater
Diffstat (limited to 'roomserver')
21 files changed, 296 insertions, 324 deletions
diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index 0158c8f7..3be5218d 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -57,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) if err != nil { - return + return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } succeeded := false defer func() { @@ -79,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( } if err = u.doUpdateLatestEvents(); err != nil { - return err + return fmt.Errorf("u.doUpdateLatestEvents: %w", err) } succeeded = true @@ -137,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // don't need to do anything, as we've handled it already. hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID) if err != nil { - return err + return fmt.Errorf("u.updater.HasEventBeenSent: %w", err) } else if hasBeenSent { return nil } @@ -145,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Update the roomserver_previous_events table with references. This // is effectively tracking the structure of the DAG. if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil { - return err + return fmt.Errorf("u.updater.StorePreviousEvents: %w", err) } // Get the event reference for our new event. This will be used when @@ -156,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // in the room. If it is then it isn't a latest event. alreadyReferenced, err := u.updater.IsReferenced(eventReference) if err != nil { - return err + return fmt.Errorf("u.updater.IsReferenced: %w", err) } // Work out what the latest events are. @@ -173,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // Now that we know what the latest events are, it's time to get the // latest state. if err = u.latestState(); err != nil { - return err + return fmt.Errorf("u.latestState: %w", err) } // If we need to generate any output events then here's where we do it. // TODO: Move this! updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if err != nil { - return err + return fmt.Errorf("u.api.updateMemberships: %w", err) } update, err := u.makeOutputNewRoomEvent() if err != nil { - return err + return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } updates = append(updates, *update) @@ -198,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { // the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the // necessary bookkeeping we'll keep the event sending synchronous for now. if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil { - return err + return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { - return err + return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } - return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID) + if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil { + return fmt.Errorf("u.updater.MarkEventAsSent: %w", err) + } + + return nil } func (u *latestEventsUpdater) latestState() error { @@ -225,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.roomNID, latestStateAtEvents, ) if err != nil { - return err + return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) } // If we are overwriting the state then we should make sure that we @@ -244,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error { u.ctx, u.oldStateNID, u.newStateNID, ) if err != nil { - return err + return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err) } // Also work out the state before the event removes and the event @@ -252,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error { u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots( u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID, ) - return err + if err != nil { + return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err) + } + + return nil } func calculateLatest( diff --git a/roomserver/state/state.go b/roomserver/state/state.go index d5be4a90..b9ad4a50 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil)) + stateNID, err := v.db.AddState(ctx, roomNID, nil, nil) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } if len(prevStates) == 1 { @@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( ) if err != nil { metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, err) + return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) } stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs if len(stateBlockNIDs) < maxStateBlockNIDs { // 4) The number of state data blocks is small enough that we can just // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" - return metrics.stop(v.db.AddState( + stateNID, err := v.db.AddState( ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, - )) + ) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) } // If there are too many deltas then we need to calculate the full state // So fall through to calculateAndStoreStateAfterManyEvents } - return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics) + if err != nil { + return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) + } + return stateNID, nil } // maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state. diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 52ff479b..0b7ed225 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: db, + Writer: sqlutil.NewDummyTransactionWriter(), EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, EventJSONTable: eventJSON, diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index 21b168a4..e9a0f698 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -65,12 +66,14 @@ func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { // StorePreviousEvents implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) + } } - } - return nil + return nil + }) } // IsReferenced implements types.RoomRecentEventsUpdater @@ -82,7 +85,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even if err == sql.ErrNoRows { return false, nil } - return false, err + return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err) } // SetLatestEvents implements types.RoomRecentEventsUpdater @@ -94,7 +97,12 @@ func (u *LatestEventsUpdater) SetLatestEvents( for i := range latest { eventNIDs[i] = latest[i].EventNID } - return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil { + return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err) + } + return nil + }) } // HasEventBeenSent implements types.RoomRecentEventsUpdater @@ -104,7 +112,9 @@ func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, e // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { + return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID) + }) } func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 5955844f..329813bf 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -3,6 +3,7 @@ package shared import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" @@ -41,9 +42,14 @@ func (d *Database) membershipUpdaterTxn( targetUserNID types.EventStateKeyNID, targetLocal bool, ) (*MembershipUpdater, error) { - - if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { - return nil, err + err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil { + return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("u.d.Writer.Do: %w", err) } membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) @@ -75,19 +81,19 @@ func (u *MembershipUpdater) IsLeave() bool { func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) if err != nil { - return false, err + return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inserted, err := u.d.InvitesTable.InsertInviteEvent( u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), ) if err != nil { - return false, err + return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { if err = u.d.MembershipTable.UpdateMembership( u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, ); err != nil { - return false, err + return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inserted, nil @@ -99,7 +105,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } // If this is a join event update, there is no invite to update @@ -108,14 +114,14 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) } } // Look up the NID of the new join event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateJoin || isUpdate { @@ -123,7 +129,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -134,19 +140,19 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired( u.ctx, u.txn, u.roomNID, u.targetUserNID, ) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err) } // Look up the NID of the new leave event nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) if err != nil { - return nil, err + return nil, fmt.Errorf("u.d.EventNIDs: %w", err) } if u.membership != tables.MembershipStateLeaveOrBan { @@ -154,7 +160,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], ); err != nil { - return nil, err + return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } return inviteEventIDs, nil diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 00179e33..45020d55 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -27,6 +27,7 @@ const redactionsArePermanent = false type Database struct { DB *sql.DB + Writer sqlutil.TransactionWriter EventsTable tables.Events EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes @@ -83,20 +84,23 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { - return err + return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err) } stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) } stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) - return err + if err != nil { + return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err) + } + return nil }) if err != nil { - return 0, err + return 0, fmt.Errorf("d.Writer.Do: %w", err) } return } @@ -110,7 +114,9 @@ func (d *Database) EventNIDs( func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + }) } func (d *Database) StateAtEventIDs( @@ -221,7 +227,9 @@ func (d *Database) GetRoomVersionForRoomNID( } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + }) } func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { @@ -239,15 +247,21 @@ func (d *Database) GetCreatorIDForAlias( } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + }) } func (d *Database) GetMembership( ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, ) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + var requestSenderUserNID types.EventStateKeyNID + err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID) + return err + }) if err != nil { - return + return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err) } senderMembershipEventNID, senderMembership, err := @@ -350,6 +364,7 @@ func (d *Database) GetLatestEventsForUpdate( return NewLatestEventsUpdater(ctx, d, txn, roomNID) } +// nolint:gocyclo func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, @@ -365,10 +380,10 @@ func (d *Database) StoreEvent( err error ) - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( - ctx, txn, txnAndSessionID.TransactionID, + ctx, nil, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err) @@ -433,7 +448,7 @@ func (d *Database) StoreEvent( return nil }) if err != nil { - return 0, types.StateAtEvent{}, nil, "", err + return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err) } return roomNID, types.StateAtEvent{ @@ -449,7 +464,9 @@ func (d *Database) StoreEvent( } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { - return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { + return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + }) } func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) { diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go index e8118ad7..3cd44b1d 100644 --- a/roomserver/storage/sqlite3/event_json_table.go +++ b/roomserver/storage/sqlite3/event_json_table.go @@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = ` type eventJSONStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventJSONStmt *sql.Stmt bulkSelectEventJSONStmt *sql.Stmt } -func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) { +func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) { s := &eventJSONStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventJSONSchema) if err != nil { @@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) - return err - }) + _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON) + return err } func (s *eventJSONStatements) BulkSelectEventJSON( diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go index c8ad052b..345df8c6 100644 --- a/roomserver/storage/sqlite3/event_state_keys_table.go +++ b/roomserver/storage/sqlite3/event_state_keys_table.go @@ -64,17 +64,15 @@ const bulkSelectEventStateKeyNIDSQL = ` type eventStateKeyStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt bulkSelectEventStateKeyStmt *sql.Stmt } -func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) { +func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) { s := &eventStateKeyStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventStateKeysSchema) if err != nil { @@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) func (s *eventStateKeyStatements) InsertEventStateKeyNID( ctx context.Context, txn *sql.Tx, eventStateKey string, ) (types.EventStateKeyNID, error) { - var eventStateKeyNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) - res, err := insertStmt.ExecContext(ctx, eventStateKey) - if err != nil { - return err - } - eventStateKeyNID, err = res.LastInsertId() - if err != nil { - return err - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt) + res, err := insertStmt.ExecContext(ctx, eventStateKey) + if err != nil { + return 0, err + } + eventStateKeyNID, err := res.LastInsertId() + if err != nil { + return 0, err + } return types.EventStateKeyNID(eventStateKeyNID), err } diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go index 4a645789..26e2bf84 100644 --- a/roomserver/storage/sqlite3/event_types_table.go +++ b/roomserver/storage/sqlite3/event_types_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "strings" "github.com/matrix-org/dendrite/internal" @@ -78,17 +79,15 @@ const bulkSelectEventTypeNIDSQL = ` type eventTypeStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDResultStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt bulkSelectEventTypeNIDStmt *sql.Stmt } -func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) { +func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) { s := &eventTypeStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventTypesSchema) if err != nil { @@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta } func (s *eventTypeStatements) InsertEventTypeNID( - ctx context.Context, tx *sql.Tx, eventType string, + ctx context.Context, txn *sql.Tx, eventType string, ) (types.EventTypeNID, error) { var eventTypeNID int64 - err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error { - insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt) - resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt) - _, err := insertStmt.ExecContext(ctx, eventType) - if err != nil { - return err - } - return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID) - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt) + resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt) + _, err := insertStmt.ExecContext(ctx, eventType) + if err != nil { + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil { + return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err) + } return types.EventTypeNID(eventTypeNID), err } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 0e39755c..26ea1d41 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt @@ -115,10 +114,9 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) { +func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { s := &eventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(eventsSchema) if err != nil { @@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent( ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - result, err := insertStmt.ExecContext( - ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ) - if err != nil { - return err - } - modified, err := result.RowsAffected() - if modified == 0 && err == nil { - return sql.ErrNoRows - } - eventNID, err = result.LastInsertId() - return err - }) + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + result, err := insertStmt.ExecContext( + ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + ) + if err != nil { + return 0, 0, err + } + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, 0, sql.ErrNoRows + } + eventNID, err = result.LastInsertId() return types.EventNID(eventNID), 0, err } @@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID( func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) - _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) - return err - }) + _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + return err } func (s *eventStatements) SelectEventSentToOutput( @@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput( } func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) - _, err := updateStmt.ExecContext(ctx, int64(eventNID)) - return err - }) + updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt) + _, err := updateStmt.ExecContext(ctx, int64(eventNID)) + return err } func (s *eventStatements) SelectEventID( @@ -334,7 +324,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference( rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...) if err != nil { - return nil, err + return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err) } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed") results := make([]types.StateAtEventAndReference, len(eventNIDs)) @@ -481,7 +471,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, } err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result) if err != nil { - return 0, err + return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err) } return result, nil } diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go index 1305f4a8..327be6a0 100644 --- a/roomserver/storage/sqlite3/invite_table.go +++ b/roomserver/storage/sqlite3/invite_table.go @@ -64,17 +64,15 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni type inviteStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertInviteEventStmt *sql.Stmt selectInviteActiveForUserInRoomStmt *sql.Stmt updateInviteRetiredStmt *sql.Stmt selectInvitesAboutToRetireStmt *sql.Stmt } -func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) { +func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) { s := &inviteStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(inviteSchema) if err != nil { @@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent( inviteEventJSON []byte, ) (bool, error) { var count int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) - result, err := stmt.ExecContext( - ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, - ) - if err != nil { - return err - } - count, err = result.RowsAffected() - if err != nil { - return err - } - return nil - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + result, err := stmt.ExecContext( + ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON, + ) + if err != nil { + return false, err + } + count, err = result.RowsAffected() + if err != nil { + return false, err + } return count != 0, err } @@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, ) (eventIDs []string, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - // gather all the event IDs we will retire - stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) - rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) - if err != nil { - return err - } - defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") - for rows.Next() { - var inviteEventID string - if err = rows.Scan(&inviteEventID); err != nil { - return err - } - eventIDs = append(eventIDs, inviteEventID) + // gather all the event IDs we will retire + stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt) + rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed") + for rows.Next() { + var inviteEventID string + if err = rows.Scan(&inviteEventID); err != nil { + return } - // now retire the invites - stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) - _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) - return err - }) + eventIDs = append(eventIDs, inviteEventID) + } + // now retire the invites + stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt) + _, err = stmt.ExecContext(ctx, roomNID, targetUserNID) return } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 7b69cee3..b3ee69c0 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -77,7 +77,6 @@ const updateMembershipSQL = "" + type membershipStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt selectMembershipFromRoomAndTargetStmt *sql.Stmt @@ -88,10 +87,9 @@ type membershipStatements struct { updateMembershipStmt *sql.Stmt } -func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) { +func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { s := &membershipStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(membershipSchema) if err != nil { @@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership( roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) - _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt) + _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget) + return err } func (s *membershipStatements) SelectMembershipForUpdate( @@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership( senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( - ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) + _, err := stmt.ExecContext( + ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID, + ) + return err } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index ff804861..d28a42c6 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = ` type previousEventStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertPreviousEventStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } -func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) { +func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { s := &previousEventStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(previousEventSchema) if err != nil { @@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err := stmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + ) + return err } // Check if the event reference exists diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index a4a47aec..1d6ccd56 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -45,16 +44,14 @@ const selectPublishedSQL = "" + type publishedStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter upsertPublishedStmt *sql.Stmt selectAllPublishedStmt *sql.Stmt selectPublishedStmt *sql.Stmt } -func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) { +func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { s := &publishedStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(publishedSchema) if err != nil { @@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab func (s *publishedStatements) UpsertRoomPublished( ctx context.Context, roomID string, published bool, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) - _, err := stmt.ExecContext(ctx, roomID, published) - return err - }) +) error { + _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + return err } func (s *publishedStatements) SelectPublishedFromRoomID( diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go index ad900a4e..a2179357 100644 --- a/roomserver/storage/sqlite3/redactions_table.go +++ b/roomserver/storage/sqlite3/redactions_table.go @@ -53,17 +53,15 @@ const markRedactionValidatedSQL = "" + type redactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRedactionStmt *sql.Stmt selectRedactionInfoByRedactionEventIDStmt *sql.Stmt selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt markRedactionValidatedStmt *sql.Stmt } -func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) { +func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) { s := &redactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(redactionsSchema) if err != nil { @@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta func (s *redactionStatements) InsertRedaction( ctx context.Context, txn *sql.Tx, info tables.RedactionInfo, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) - _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt) + _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated) + return err } func (s *redactionStatements) SelectRedactionInfoByRedactionEventID( @@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted( func (s *redactionStatements) MarkRedactionValidated( ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) - _, err := stmt.ExecContext(ctx, redactionEventID, validated) - return err - }) + stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt) + _, err := stmt.ExecContext(ctx, redactionEventID, validated) + return err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index deba3ff5..a16e97aa 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -57,7 +56,6 @@ const deleteRoomAliasSQL = ` type roomAliasesStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomAliasStmt *sql.Stmt selectRoomIDFromAliasStmt *sql.Stmt selectAliasesFromRoomIDStmt *sql.Stmt @@ -65,10 +63,9 @@ type roomAliasesStatements struct { deleteRoomAliasStmt *sql.Stmt } -func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) { +func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { s := &roomAliasesStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomAliasesSchema) if err != nil { @@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (t func (s *roomAliasesStatements) InsertRoomAlias( ctx context.Context, alias string, roomID string, creatorUserID string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) - return err - }) +) error { + _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + return err } func (s *roomAliasesStatements) SelectRoomIDFromAlias( @@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( func (s *roomAliasesStatements) DeleteRoomAlias( ctx context.Context, alias string, -) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) - _, err := stmt.ExecContext(ctx, alias) - return err - }) +) error { + _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + return err } diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 8bbec508..6541cc0c 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" + type roomStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt @@ -76,10 +75,9 @@ type roomStatements struct { selectRoomVersionForRoomNIDStmt *sql.Stmt } -func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) { +func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) { s := &roomStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(roomsSchema) if err != nil { @@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, ) (roomNID types.RoomNID, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) - _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) - if err != nil { - return fmt.Errorf("insertStmt.ExecContext: %w", err) - } - roomNID, err = s.SelectRoomNID(ctx, txn, roomID) - if err != nil { - return fmt.Errorf("s.SelectRoomNID: %w", err) - } - return nil - }) + insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt) + _, err = insertStmt.ExecContext(ctx, roomID, roomVersion) if err != nil { - return types.RoomNID(0), err + return 0, fmt.Errorf("insertStmt.ExecContext: %w", err) + } + roomNID, err = s.SelectRoomNID(ctx, txn, roomID) + if err != nil { + return 0, fmt.Errorf("s.SelectRoomNID: %w", err) } return } @@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs( lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) - _, err := stmt.ExecContext( - ctx, - eventNIDsAsArray(eventNIDs), - int64(lastEventSentNID), - int64(stateSnapshotNID), - roomNID, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt) + _, err := stmt.ExecContext( + ctx, + eventNIDsAsArray(eventNIDs), + int64(lastEventSentNID), + int64(stateSnapshotNID), + roomNID, + ) + return err } func (s *roomStatements) SelectRoomVersionForRoomID( diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 3e28e450..8033903f 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -74,17 +74,15 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" + type stateBlockStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateDataStmt *sql.Stmt selectNextStateBlockNIDStmt *sql.Stmt bulkSelectStateBlockEntriesStmt *sql.Stmt bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt } -func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) { +func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { s := &stateBlockStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateDataSchema) if err != nil { @@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData( return 0, nil } var stateBlockNID types.StateBlockNID - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + if err != nil { + return 0, err + } + for _, entry := range entries { + _, err = txn.Stmt(s.insertStateDataStmt).ExecContext( + ctx, + int64(stateBlockNID), + int64(entry.EventTypeNID), + int64(entry.EventStateKeyNID), + int64(entry.EventNID), + ) if err != nil { - return err + return 0, err } - for _, entry := range entries { - _, err := txn.Stmt(s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return err - } - } - return nil - }) + } return stateBlockNID, err } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 799904ff..392c2a67 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" + type stateSnapshotStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt } -func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) { +func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { s := &stateSnapshotStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(stateSnapshotSchema) if err != nil { @@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState( if err != nil { return } - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - insertStmt := txn.Stmt(s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) - if err != nil { - return err - } - lastRowID, err := res.LastInsertId() - if err != nil { - return err - } - stateNID = types.StateSnapshotNID(lastRowID) - return nil - }) + insertStmt := txn.Stmt(s.insertStateStmt) + res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + if err != nil { + return 0, err + } + lastRowID, err := res.LastInsertId() + if err != nil { + return 0, err + } + stateNID = types.StateSnapshotNID(lastRowID) return } diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 72431637..8e3af6b7 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -41,6 +41,7 @@ type Database struct { invites tables.Invites membership tables.Membership db *sql.DB + writer sqlutil.TransactionWriter } // Open a sqlite database. @@ -51,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - writer := sqlutil.NewTransactionWriter() + d.writer = sqlutil.NewTransactionWriter() //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") @@ -61,64 +62,65 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { // which it will never obtain. d.db.SetMaxOpenConns(20) - d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer) + d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db) if err != nil { return nil, err } - d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer) + d.eventTypes, err = NewSqliteEventTypesTable(d.db) if err != nil { return nil, err } - d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer) + d.eventJSON, err = NewSqliteEventJSONTable(d.db) if err != nil { return nil, err } - d.events, err = NewSqliteEventsTable(d.db, writer) + d.events, err = NewSqliteEventsTable(d.db) if err != nil { return nil, err } - d.rooms, err = NewSqliteRoomsTable(d.db, writer) + d.rooms, err = NewSqliteRoomsTable(d.db) if err != nil { return nil, err } - d.transactions, err = NewSqliteTransactionsTable(d.db, writer) + d.transactions, err = NewSqliteTransactionsTable(d.db) if err != nil { return nil, err } - stateBlock, err := NewSqliteStateBlockTable(d.db, writer) + stateBlock, err := NewSqliteStateBlockTable(d.db) if err != nil { return nil, err } - stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer) + stateSnapshot, err := NewSqliteStateSnapshotTable(d.db) if err != nil { return nil, err } - d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer) + d.prevEvents, err = NewSqlitePrevEventsTable(d.db) if err != nil { return nil, err } - roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer) + roomAliases, err := NewSqliteRoomAliasesTable(d.db) if err != nil { return nil, err } - d.invites, err = NewSqliteInvitesTable(d.db, writer) + d.invites, err = NewSqliteInvitesTable(d.db) if err != nil { return nil, err } - d.membership, err = NewSqliteMembershipTable(d.db, writer) + d.membership, err = NewSqliteMembershipTable(d.db) if err != nil { return nil, err } - published, err := NewSqlitePublishedTable(d.db, writer) + published, err := NewSqlitePublishedTable(d.db) if err != nil { return nil, err } - redactions, err := NewSqliteRedactionsTable(d.db, writer) + redactions, err := NewSqliteRedactionsTable(d.db) if err != nil { return nil, err } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewTransactionWriter(), EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go index 65c18a8a..029122c5 100644 --- a/roomserver/storage/sqlite3/transactions_table.go +++ b/roomserver/storage/sqlite3/transactions_table.go @@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = ` type transactionStatements struct { db *sql.DB - writer *sqlutil.TransactionWriter insertTransactionStmt *sql.Stmt selectTransactionEventIDStmt *sql.Stmt } -func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) { +func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) { s := &transactionStatements{ - db: db, - writer: writer, + db: db, } _, err := db.Exec(transactionsSchema) if err != nil { @@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction( sessionID int64, userID string, eventID string, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) - _, err := stmt.ExecContext( - ctx, transactionID, sessionID, userID, eventID, - ) - return err - }) +) error { + stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt) + _, err := stmt.ExecContext( + ctx, transactionID, sessionID, userID, eventID, + ) + return err } func (s *transactionStatements) SelectTransactionEventID( |