aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--syncapi/storage/postgres/notification_data_table.go2
-rw-r--r--syncapi/storage/sqlite3/notification_data_table.go16
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go8
-rw-r--r--syncapi/storage/sqlite3/syncserver.go2
4 files changed, 21 insertions, 7 deletions
diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go
index f3fc4451..9cd8b736 100644
--- a/syncapi/storage/postgres/notification_data_table.go
+++ b/syncapi/storage/postgres/notification_data_table.go
@@ -58,7 +58,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
- DO UPDATE SET notification_count = $3, highlight_count = $4
+ DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go
index 4b3f074d..eaa11a8c 100644
--- a/syncapi/storage/sqlite3/notification_data_table.go
+++ b/syncapi/storage/sqlite3/notification_data_table.go
@@ -25,12 +25,14 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
)
-func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
+func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema)
if err != nil {
return nil, err
}
- r := &notificationDataStatements{}
+ r := &notificationDataStatements{
+ streamIDStatements: streamID,
+ }
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
@@ -39,6 +41,7 @@ func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error)
}
type notificationDataStatements struct {
+ streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
@@ -58,8 +61,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
- DO UPDATE SET notification_count = $3, highlight_count = $4
- RETURNING id`
+ DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
@@ -71,7 +73,11 @@ const selectUserUnreadNotificationCountsSQL = `SELECT
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
- err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
+ pos, err = r.streamIDStatements.nextNotificationID(ctx, nil)
+ if err != nil {
+ return
+ }
+ _, err = r.upsertRoomUnreadCounts.ExecContext(ctx, userID, roomID, notificationCount, highlightCount, pos, notificationCount, highlightCount)
return
}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index 71980b80..1160a437 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -26,6 +26,8 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0)
ON CONFLICT DO NOTHING;
+INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0)
+ ON CONFLICT DO NOTHING;
`
const increaseStreamIDStmt = "" +
@@ -78,3 +80,9 @@ func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (p
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return
}
+
+func (s *StreamIDStatements) nextNotificationID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+ increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
+ err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos)
+ return
+}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 65b2bb38..5c5eb0f5 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -95,7 +95,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
- notificationData, err := NewSqliteNotificationDataTable(d.db)
+ notificationData, err := NewSqliteNotificationDataTable(d.db, &d.streamID)
if err != nil {
return err
}