aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-09-09 13:06:42 +0100
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-09-09 13:06:42 +0100
commit646de03d60fa1ca78dbf0b4d5418600d540fb881 (patch)
treebb6f6638a9e23d418a1812516fc0b6982e0c47bf /syncapi
parent34e1dc210bf4698d2b46a0be71e6dfc10a5a6365 (diff)
More writer fixes in the Sync API
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/storage/postgres/ignores_table.go9
-rw-r--r--syncapi/storage/postgres/notification_data_table.go12
-rw-r--r--syncapi/storage/shared/syncserver.go22
-rw-r--r--syncapi/storage/sqlite3/ignores_table.go9
-rw-r--r--syncapi/storage/sqlite3/notification_data_table.go10
-rw-r--r--syncapi/storage/tables/interface.go10
6 files changed, 41 insertions, 31 deletions
diff --git a/syncapi/storage/postgres/ignores_table.go b/syncapi/storage/postgres/ignores_table.go
index 055a1a23..97660725 100644
--- a/syncapi/storage/postgres/ignores_table.go
+++ b/syncapi/storage/postgres/ignores_table.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -61,10 +62,10 @@ func NewPostgresIgnoresTable(db *sql.DB) (tables.Ignores, error) {
}
func (s *ignoresStatements) SelectIgnores(
- ctx context.Context, userID string,
+ ctx context.Context, txn *sql.Tx, userID string,
) (*types.IgnoredUsers, error) {
var ignoresData []byte
- err := s.selectIgnoresStmt.QueryRowContext(ctx, userID).Scan(&ignoresData)
+ err := sqlutil.TxStmt(txn, s.selectIgnoresStmt).QueryRowContext(ctx, userID).Scan(&ignoresData)
if err != nil {
return nil, err
}
@@ -76,12 +77,12 @@ func (s *ignoresStatements) SelectIgnores(
}
func (s *ignoresStatements) UpsertIgnores(
- ctx context.Context, userID string, ignores *types.IgnoredUsers,
+ ctx context.Context, txn *sql.Tx, userID string, ignores *types.IgnoredUsers,
) error {
ignoresJSON, err := json.Marshal(ignores)
if err != nil {
return err
}
- _, err = s.upsertIgnoresStmt.ExecContext(ctx, userID, ignoresJSON)
+ _, err = sqlutil.TxStmt(txn, s.upsertIgnoresStmt).ExecContext(ctx, userID, ignoresJSON)
return err
}
diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go
index 9cd8b736..708c3a9b 100644
--- a/syncapi/storage/postgres/notification_data_table.go
+++ b/syncapi/storage/postgres/notification_data_table.go
@@ -70,13 +70,13 @@ const selectUserUnreadNotificationCountsSQL = `SELECT
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
-func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
- err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
+func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+ err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
return
}
-func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
- rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
+func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
+ rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
if err != nil {
return nil, err
}
@@ -101,8 +101,8 @@ func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context,
return roomCounts, rows.Err()
}
-func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
+func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) {
var id int64
- err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
+ err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id)
return id, err
}
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index b06d2c6a..778ad8b1 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -108,7 +108,7 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
}
func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.NotificationData.SelectMaxID(ctx)
+ id, err := d.NotificationData.SelectMaxID(ctx, nil)
if err != nil {
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
}
@@ -1029,15 +1029,15 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream
}
func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
- err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
- pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount)
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount)
return err
})
return
}
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
- return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
+ return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to)
}
func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
@@ -1052,15 +1052,23 @@ func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID s
}
func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
- return d.Ignores.SelectIgnores(ctx, userID)
+ return d.Ignores.SelectIgnores(ctx, nil, userID)
}
func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
- return d.Ignores.UpsertIgnores(ctx, userID, ignores)
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores)
+ })
}
func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
- return d.Presence.UpsertPresence(ctx, nil, userID, statusMsg, presence, lastActiveTS, fromSync)
+ var pos types.StreamPosition
+ var err error
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync)
+ return nil
+ })
+ return pos, err
}
func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
diff --git a/syncapi/storage/sqlite3/ignores_table.go b/syncapi/storage/sqlite3/ignores_table.go
index f4afca55..5ee1a9fa 100644
--- a/syncapi/storage/sqlite3/ignores_table.go
+++ b/syncapi/storage/sqlite3/ignores_table.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -61,10 +62,10 @@ func NewSqliteIgnoresTable(db *sql.DB) (tables.Ignores, error) {
}
func (s *ignoresStatements) SelectIgnores(
- ctx context.Context, userID string,
+ ctx context.Context, txn *sql.Tx, userID string,
) (*types.IgnoredUsers, error) {
var ignoresData []byte
- err := s.selectIgnoresStmt.QueryRowContext(ctx, userID).Scan(&ignoresData)
+ err := sqlutil.TxStmt(txn, s.selectIgnoresStmt).QueryRowContext(ctx, userID).Scan(&ignoresData)
if err != nil {
return nil, err
}
@@ -76,12 +77,12 @@ func (s *ignoresStatements) SelectIgnores(
}
func (s *ignoresStatements) UpsertIgnores(
- ctx context.Context, userID string, ignores *types.IgnoredUsers,
+ ctx context.Context, txn *sql.Tx, userID string, ignores *types.IgnoredUsers,
) error {
ignoresJSON, err := json.Marshal(ignores)
if err != nil {
return err
}
- _, err = s.upsertIgnoresStmt.ExecContext(ctx, userID, ignoresJSON)
+ _, err = sqlutil.TxStmt(txn, s.upsertIgnoresStmt).ExecContext(ctx, userID, ignoresJSON)
return err
}
diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go
index eaa11a8c..66d4d438 100644
--- a/syncapi/storage/sqlite3/notification_data_table.go
+++ b/syncapi/storage/sqlite3/notification_data_table.go
@@ -72,7 +72,7 @@ const selectUserUnreadNotificationCountsSQL = `SELECT
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
-func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
pos, err = r.streamIDStatements.nextNotificationID(ctx, nil)
if err != nil {
return
@@ -81,8 +81,8 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context,
return
}
-func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
- rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
+func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
+ rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl)
if err != nil {
return nil, err
}
@@ -107,8 +107,8 @@ func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context,
return roomCounts, rows.Err()
}
-func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
+func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) {
var id int64
- err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
+ err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id)
return id, err
}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 468d26ac..193881b4 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -189,14 +189,14 @@ type Memberships interface {
}
type NotificationData interface {
- UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
- SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
- SelectMaxID(ctx context.Context) (int64, error)
+ UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
+ SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
+ SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error)
}
type Ignores interface {
- SelectIgnores(ctx context.Context, userID string) (*types.IgnoredUsers, error)
- UpsertIgnores(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
+ SelectIgnores(ctx context.Context, txn *sql.Tx, userID string) (*types.IgnoredUsers, error)
+ UpsertIgnores(ctx context.Context, txn *sql.Tx, userID string, ignores *types.IgnoredUsers) error
}
type Presence interface {