aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-03-17 17:05:21 +0000
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-03-17 17:05:21 +0000
commit4e64c270dbe5d438325903e4404ed4b9ec43c039 (patch)
tree703043599b6a3ed316df980de493b7ba03156d4e /syncapi
parent0fb94fc781a71219d5e537788e976bec1d84382c (diff)
Various bug fixes and tweaks around invites and membership
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/storage/postgres/invites_table.go2
-rw-r--r--syncapi/storage/sqlite3/invites_table.go2
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go34
3 files changed, 8 insertions, 30 deletions
diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go
index 48ad58c0..97001ae2 100644
--- a/syncapi/storage/postgres/invites_table.go
+++ b/syncapi/storage/postgres/invites_table.go
@@ -52,7 +52,7 @@ const insertInviteEventSQL = "" +
") VALUES ($1, $2, $3, $4, FALSE) RETURNING id"
const deleteInviteEventSQL = "" +
- "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id"
+ "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 7498fd68..0a6823cc 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -47,7 +47,7 @@ const insertInviteEventSQL = "" +
" VALUES ($1, $2, $3, $4, $5, false)"
const deleteInviteEventSQL = "" +
- "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
+ "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index b614271d..2be3ae93 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -27,15 +27,12 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
`
const increaseStreamIDStmt = "" +
- "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
-
-const selectStreamIDStmt = "" +
- "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
+ "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
+ " RETURNING stream_id"
type streamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
- selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
@@ -47,48 +44,29 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
return
}
- if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
- return
- }
return
}
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
- selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
- return
- }
- err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
+ err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
- selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
- return
- }
- err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
+ err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
- selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
- return
- }
- err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
+ err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
- selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
- return
- }
- err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
+ err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}