diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-09-09 13:06:42 +0100 |
---|---|---|
committer | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-09-09 13:06:42 +0100 |
commit | 646de03d60fa1ca78dbf0b4d5418600d540fb881 (patch) | |
tree | bb6f6638a9e23d418a1812516fc0b6982e0c47bf /syncapi/storage/sqlite3 | |
parent | 34e1dc210bf4698d2b46a0be71e6dfc10a5a6365 (diff) |
More writer fixes in the Sync API
Diffstat (limited to 'syncapi/storage/sqlite3')
-rw-r--r-- | syncapi/storage/sqlite3/ignores_table.go | 9 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/notification_data_table.go | 10 |
2 files changed, 10 insertions, 9 deletions
diff --git a/syncapi/storage/sqlite3/ignores_table.go b/syncapi/storage/sqlite3/ignores_table.go index f4afca55..5ee1a9fa 100644 --- a/syncapi/storage/sqlite3/ignores_table.go +++ b/syncapi/storage/sqlite3/ignores_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -61,10 +62,10 @@ func NewSqliteIgnoresTable(db *sql.DB) (tables.Ignores, error) { } func (s *ignoresStatements) SelectIgnores( - ctx context.Context, userID string, + ctx context.Context, txn *sql.Tx, userID string, ) (*types.IgnoredUsers, error) { var ignoresData []byte - err := s.selectIgnoresStmt.QueryRowContext(ctx, userID).Scan(&ignoresData) + err := sqlutil.TxStmt(txn, s.selectIgnoresStmt).QueryRowContext(ctx, userID).Scan(&ignoresData) if err != nil { return nil, err } @@ -76,12 +77,12 @@ func (s *ignoresStatements) SelectIgnores( } func (s *ignoresStatements) UpsertIgnores( - ctx context.Context, userID string, ignores *types.IgnoredUsers, + ctx context.Context, txn *sql.Tx, userID string, ignores *types.IgnoredUsers, ) error { ignoresJSON, err := json.Marshal(ignores) if err != nil { return err } - _, err = s.upsertIgnoresStmt.ExecContext(ctx, userID, ignoresJSON) + _, err = sqlutil.TxStmt(txn, s.upsertIgnoresStmt).ExecContext(ctx, userID, ignoresJSON) return err } diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index eaa11a8c..66d4d438 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -72,7 +72,7 @@ const selectUserUnreadNotificationCountsSQL = `SELECT const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` -func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { +func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { return @@ -81,8 +81,8 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { + rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) if err != nil { return nil, err } @@ -107,8 +107,8 @@ func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, return roomCounts, rows.Err() } -func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) { +func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) { var id int64 - err := r.selectMaxID.QueryRowContext(ctx).Scan(&id) + err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } |