aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/postgres
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage/postgres')
-rw-r--r--syncapi/storage/postgres/receipt_table.go17
-rw-r--r--syncapi/storage/postgres/send_to_device_table.go12
2 files changed, 19 insertions, 10 deletions
diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go
index 23c66910..73bf4179 100644
--- a/syncapi/storage/postgres/receipt_table.go
+++ b/syncapi/storage/postgres/receipt_table.go
@@ -55,7 +55,7 @@ const upsertReceipt = "" +
" RETURNING id"
const selectRoomReceipts = "" +
- "SELECT room_id, receipt_type, user_id, event_id, receipt_ts" +
+ "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" +
" WHERE room_id = ANY($1) AND id > $2"
@@ -95,22 +95,27 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return
}
-func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) {
+func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
+ lastPos := types.StreamPosition(0)
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil {
- return nil, fmt.Errorf("unable to query room receipts: %w", err)
+ return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent
for rows.Next() {
r := api.OutputReceiptEvent{}
- err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
+ var id types.StreamPosition
+ err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil {
- return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
+ return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
}
res = append(res, r)
+ if id > lastPos {
+ lastPos = id
+ }
}
- return res, rows.Err()
+ return lastPos, res, rows.Err()
}
func (s *receiptStatements) SelectMaxReceiptID(
diff --git a/syncapi/storage/postgres/send_to_device_table.go b/syncapi/storage/postgres/send_to_device_table.go
index be9c347b..ac60989c 100644
--- a/syncapi/storage/postgres/send_to_device_table.go
+++ b/syncapi/storage/postgres/send_to_device_table.go
@@ -49,6 +49,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
+ RETURNING id
`
const countSendToDeviceMessagesSQL = `
@@ -107,8 +108,8 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
-) (err error) {
- _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+) (pos types.StreamPosition, err error) {
+ err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
return
}
@@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
-) (events []types.SendToDeviceEvent, err error) {
+) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
@@ -152,9 +153,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
}
}
events = append(events, event)
+ if types.StreamPosition(id) > lastPos {
+ lastPos = types.StreamPosition(id)
+ }
}
- return events, rows.Err()
+ return lastPos, events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(