aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-07-21 10:48:49 +0100
committerGitHub <noreply@github.com>2020-07-21 10:48:49 +0100
commitd76eb1b99491f644be035a631a08b5874065e4d7 (patch)
treee69476061eea083f7b8c670927a3d3a796c8abbf
parent489f34fed7fccd59c0788536894561157d1089c1 (diff)
Use TransactionWriter in roomserver SQLite (#1208)
-rw-r--r--roomserver/storage/sqlite3/event_json_table.go13
-rw-r--r--roomserver/storage/sqlite3/event_state_keys_table.go23
-rw-r--r--roomserver/storage/sqlite3/event_types_table.go22
-rw-r--r--roomserver/storage/sqlite3/events_table.go55
-rw-r--r--roomserver/storage/sqlite3/invite_table.go71
-rw-r--r--roomserver/storage/sqlite3/membership_table.go27
-rw-r--r--roomserver/storage/sqlite3/previous_events_table.go19
-rw-r--r--roomserver/storage/sqlite3/published_table.go14
-rw-r--r--roomserver/storage/sqlite3/redactions_table.go23
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go20
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go36
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go41
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go24
-rw-r--r--roomserver/storage/sqlite3/transactions_table.go19
14 files changed, 259 insertions, 148 deletions
diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go
index 6368675b..64795d02 100644
--- a/roomserver/storage/sqlite3/event_json_table.go
+++ b/roomserver/storage/sqlite3/event_json_table.go
@@ -49,13 +49,16 @@ const bulkSelectEventJSONSQL = `
type eventJSONStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
}
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
- s := &eventJSONStatements{}
- s.db = db
+ s := &eventJSONStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(eventJSONSchema)
if err != nil {
return nil, err
@@ -69,8 +72,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error {
- _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
+ return err
+ })
}
func (s *eventJSONStatements) BulkSelectEventJSON(
diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go
index cbea8428..3e9f2e61 100644
--- a/roomserver/storage/sqlite3/event_state_keys_table.go
+++ b/roomserver/storage/sqlite3/event_state_keys_table.go
@@ -64,6 +64,7 @@ const bulkSelectEventStateKeyNIDSQL = `
type eventStateKeyStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
@@ -71,8 +72,10 @@ type eventStateKeyStatements struct {
}
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
- s := &eventStateKeyStatements{}
- s.db = db
+ s := &eventStateKeyStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(eventStateKeysSchema)
if err != nil {
return nil, err
@@ -89,12 +92,18 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
- var err error
- var res sql.Result
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
- if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
+ res, err := insertStmt.ExecContext(ctx, eventStateKey)
+ if err != nil {
+ return err
+ }
eventStateKeyNID, err = res.LastInsertId()
- }
+ if err != nil {
+ return err
+ }
+ return nil
+ })
return types.EventStateKeyNID(eventStateKeyNID), err
}
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
index c9a461f9..fd4a2e42 100644
--- a/roomserver/storage/sqlite3/event_types_table.go
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -78,6 +78,7 @@ const bulkSelectEventTypeNIDSQL = `
type eventTypeStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
@@ -85,8 +86,10 @@ type eventTypeStatements struct {
}
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
- s := &eventTypeStatements{}
- s.db = db
+ s := &eventTypeStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(eventTypesSchema)
if err != nil {
return nil, err
@@ -104,12 +107,15 @@ func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
- var err error
- insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
- resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
- if _, err = insertStmt.ExecContext(ctx, eventType); err == nil {
- err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
- }
+ err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
+ insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
+ resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
+ _, err := insertStmt.ExecContext(ctx, eventType)
+ if err != nil {
+ return err
+ }
+ return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
+ })
return types.EventTypeNID(eventTypeNID), err
}
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index d66db469..378441c3 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -99,6 +99,7 @@ const selectRoomNIDForEventNIDSQL = "" +
type eventStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
@@ -115,8 +116,10 @@ type eventStatements struct {
}
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
- s := &eventStatements{}
- s.db = db
+ s := &eventStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(eventsSchema)
if err != nil {
return nil, err
@@ -151,19 +154,23 @@ func (s *eventStatements) InsertEvent(
depth int64,
) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
- result, err := insertStmt.ExecContext(
- ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
- eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
- )
- if err != nil {
- return 0, 0, err
- }
- modified, err := result.RowsAffected()
- if modified == 0 && err == nil {
- return 0, 0, sql.ErrNoRows
- }
- eventNID, err := result.LastInsertId()
+ var eventNID int64
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+ result, err := insertStmt.ExecContext(
+ ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
+ eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
+ )
+ if err != nil {
+ return err
+ }
+ modified, err := result.RowsAffected()
+ if modified == 0 && err == nil {
+ return sql.ErrNoRows
+ }
+ eventNID, err = result.LastInsertId()
+ return err
+ })
return types.EventNID(eventNID), 0, err
}
@@ -279,8 +286,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
- _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
+ return err
+ })
}
func (s *eventStatements) SelectEventSentToOutput(
@@ -288,17 +297,15 @@ func (s *eventStatements) SelectEventSentToOutput(
) (sentToOutput bool, err error) {
selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
- //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
- if err != nil {
- }
return
}
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
- updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
- _, err := updateStmt.ExecContext(ctx, int64(eventNID))
- //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
+ _, err := updateStmt.ExecContext(ctx, int64(eventNID))
+ return err
+ })
}
func (s *eventStatements) SelectEventID(
diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go
index 8b6cbe3f..e806eab6 100644
--- a/roomserver/storage/sqlite3/invite_table.go
+++ b/roomserver/storage/sqlite3/invite_table.go
@@ -63,6 +63,8 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
`
type inviteStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
@@ -70,7 +72,10 @@ type inviteStatements struct {
}
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
- s := &inviteStatements{}
+ s := &inviteStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(inviteSchema)
if err != nil {
return nil, err
@@ -90,42 +95,48 @@ func (s *inviteStatements) InsertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
- stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
- result, err := stmt.ExecContext(
- ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
- )
- if err != nil {
- return false, err
- }
- count, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return count != 0, nil
+ var count int64
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
+ result, err := stmt.ExecContext(
+ ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
+ )
+ if err != nil {
+ return err
+ }
+ count, err = result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return count != 0, err
}
func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
- // gather all the event IDs we will retire
- stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
- rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
- if err != nil {
- return nil, err
- }
- defer (func() { err = rows.Close() })()
- for rows.Next() {
- var inviteEventID string
- if err = rows.Scan(&inviteEventID); err != nil {
- return nil, err
+ err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ // gather all the event IDs we will retire
+ stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
+ rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
+ if err != nil {
+ return err
}
- eventIDs = append(eventIDs, inviteEventID)
- }
-
- // now retire the invites
- stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
- _, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
+ defer (func() { err = rows.Close() })()
+ for rows.Next() {
+ var inviteEventID string
+ if err = rows.Scan(&inviteEventID); err != nil {
+ return err
+ }
+ eventIDs = append(eventIDs, inviteEventID)
+ }
+ // now retire the invites
+ stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
+ _, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
+ return err
+ })
return
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 6f0d763e..6dd8bd83 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -76,6 +76,8 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
@@ -87,7 +89,10 @@ type membershipStatements struct {
}
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
- s := &membershipStatements{}
+ s := &membershipStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(membershipSchema)
if err != nil {
return nil, err
@@ -110,9 +115,11 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
- stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
- _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
+ return err
+ })
}
func (s *membershipStatements) SelectMembershipForUpdate(
@@ -194,9 +201,11 @@ func (s *membershipStatements) UpdateMembership(
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID,
) error {
- stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
- _, err := stmt.ExecContext(
- ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
- )
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
+ _, err := stmt.ExecContext(
+ ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
+ )
+ return err
+ })
}
diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go
index 549aecfb..28b5d18f 100644
--- a/roomserver/storage/sqlite3/previous_events_table.go
+++ b/roomserver/storage/sqlite3/previous_events_table.go
@@ -53,12 +53,17 @@ const selectPreviousEventExistsSQL = `
`
type previousEventStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
}
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
- s := &previousEventStatements{}
+ s := &previousEventStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(previousEventSchema)
if err != nil {
return nil, err
@@ -77,11 +82,13 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
- stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
- _, err := stmt.ExecContext(
- ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
- )
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
+ _, err := stmt.ExecContext(
+ ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
+ )
+ return err
+ })
}
// Check if the event reference exists
diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go
index 9995fff6..96575241 100644
--- a/roomserver/storage/sqlite3/published_table.go
+++ b/roomserver/storage/sqlite3/published_table.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@@ -43,13 +44,18 @@ const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt
}
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
- s := &publishedStatements{}
+ s := &publishedStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(publishedSchema)
if err != nil {
return nil, err
@@ -64,8 +70,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool,
) (err error) {
- _, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
+ return err
+ })
}
func (s *publishedStatements) SelectPublishedFromRoomID(
diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go
index 1cddb9b4..d2bd2a20 100644
--- a/roomserver/storage/sqlite3/redactions_table.go
+++ b/roomserver/storage/sqlite3/redactions_table.go
@@ -52,6 +52,8 @@ const markRedactionValidatedSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
@@ -59,7 +61,10 @@ type redactionStatements struct {
}
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
- s := &redactionStatements{}
+ s := &redactionStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(redactionsSchema)
if err != nil {
return nil, err
@@ -76,9 +81,11 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error {
- stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
- _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
+ _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
+ return err
+ })
}
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
@@ -114,7 +121,9 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error {
- stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
- _, err := stmt.ExecContext(ctx, redactionEventID, validated)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
+ _, err := stmt.ExecContext(ctx, redactionEventID, validated)
+ return err
+ })
}
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index da5f9161..096b73f9 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@@ -55,6 +56,8 @@ const deleteRoomAliasSQL = `
`
type roomAliasesStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
@@ -63,7 +66,10 @@ type roomAliasesStatements struct {
}
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
- s := &roomAliasesStatements{}
+ s := &roomAliasesStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(roomAliasesSchema)
if err != nil {
return nil, err
@@ -80,8 +86,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) {
- _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
+ return err
+ })
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
@@ -130,6 +138,8 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string,
) (err error) {
- _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
+ return err
+ })
}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index ab695c5d..9eeadea9 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -64,6 +64,8 @@ const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
type roomStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
@@ -74,7 +76,10 @@ type roomStatements struct {
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
- s := &roomStatements{}
+ s := &roomStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(roomsSchema)
if err != nil {
return nil, err
@@ -94,9 +99,12 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
- var err error
- insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
- if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil {
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
+ _, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
+ return err
+ })
+ if err == nil {
return s.SelectRoomNID(ctx, txn, roomID)
} else {
return types.RoomNID(0), err
@@ -155,15 +163,17 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
- stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
- _, err := stmt.ExecContext(
- ctx,
- eventNIDsAsArray(eventNIDs),
- int64(lastEventSentNID),
- int64(stateSnapshotNID),
- roomNID,
- )
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ eventNIDsAsArray(eventNIDs),
+ int64(lastEventSentNID),
+ int64(stateSnapshotNID),
+ roomNID,
+ )
+ return err
+ })
}
func (s *roomStatements) SelectRoomVersionForRoomID(
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index c058c783..3d716b64 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -74,6 +74,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
type stateBlockStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
@@ -81,8 +82,10 @@ type stateBlockStatements struct {
}
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
- s := &stateBlockStatements{}
- s.db = db
+ s := &stateBlockStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(stateDataSchema)
if err != nil {
return nil, err
@@ -104,24 +107,26 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil
}
var stateBlockNID types.StateBlockNID
- err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
- if err != nil {
- return 0, err
- }
-
- for _, entry := range entries {
- _, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
- ctx,
- int64(stateBlockNID),
- int64(entry.EventTypeNID),
- int64(entry.EventStateKeyNID),
- int64(entry.EventNID),
- )
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil {
- return 0, err
+ return err
}
- }
- return stateBlockNID, nil
+ for _, entry := range entries {
+ _, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
+ ctx,
+ int64(stateBlockNID),
+ int64(entry.EventTypeNID),
+ int64(entry.EventStateKeyNID),
+ int64(entry.EventNID),
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return stateBlockNID, err
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index d077b617..48f1210b 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -50,13 +50,16 @@ const bulkSelectStateBlockNIDsSQL = "" +
type stateSnapshotStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
- s := &stateSnapshotStatements{}
- s.db = db
+ s := &stateSnapshotStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
@@ -75,14 +78,19 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil {
return
}
- insertStmt := txn.Stmt(s.insertStateStmt)
- if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)); err2 == nil {
- lastRowID, err3 := res.LastInsertId()
- if err3 != nil {
- err = err3
+ err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ insertStmt := txn.Stmt(s.insertStateStmt)
+ res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
+ if err != nil {
+ return err
+ }
+ lastRowID, err := res.LastInsertId()
+ if err != nil {
+ return err
}
stateNID = types.StateSnapshotNID(lastRowID)
- }
+ return nil
+ })
return
}
diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go
index 1e8de1ca..2f6cff95 100644
--- a/roomserver/storage/sqlite3/transactions_table.go
+++ b/roomserver/storage/sqlite3/transactions_table.go
@@ -44,12 +44,17 @@ const selectTransactionEventIDSQL = `
`
type transactionStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
}
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
- s := &transactionStatements{}
+ s := &transactionStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(transactionsSchema)
if err != nil {
return nil, err
@@ -68,11 +73,13 @@ func (s *transactionStatements) InsertTransaction(
userID string,
eventID string,
) (err error) {
- stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
- _, err = stmt.ExecContext(
- ctx, transactionID, sessionID, userID, eventID,
- )
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
+ _, err := stmt.ExecContext(
+ ctx, transactionID, sessionID, userID, eventID,
+ )
+ return err
+ })
}
func (s *transactionStatements) SelectTransactionEventID(