diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-03-17 17:05:21 +0000 |
---|---|---|
committer | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-03-17 17:05:21 +0000 |
commit | 4e64c270dbe5d438325903e4404ed4b9ec43c039 (patch) | |
tree | 703043599b6a3ed316df980de493b7ba03156d4e /syncapi | |
parent | 0fb94fc781a71219d5e537788e976bec1d84382c (diff) |
Various bug fixes and tweaks around invites and membership
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/storage/postgres/invites_table.go | 2 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/invites_table.go | 2 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/stream_id_table.go | 34 |
3 files changed, 8 insertions, 30 deletions
diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 48ad58c0..97001ae2 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -52,7 +52,7 @@ const insertInviteEventSQL = "" + ") VALUES ($1, $2, $3, $4, FALSE) RETURNING id" const deleteInviteEventSQL = "" + - "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id" + "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id" const selectInviteEventsInRangeSQL = "" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 7498fd68..0a6823cc 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -47,7 +47,7 @@ const insertInviteEventSQL = "" + " VALUES ($1, $2, $3, $4, $5, false)" const deleteInviteEventSQL = "" + - "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" + "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false" const selectInviteEventsInRangeSQL = "" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index b614271d..2be3ae93 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -27,15 +27,12 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) ` const increaseStreamIDStmt = "" + - "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" - -const selectStreamIDStmt = "" + - "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + + " RETURNING stream_id" type streamIDStatements struct { db *sql.DB increaseStreamIDStmt *sql.Stmt - selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { @@ -47,48 +44,29 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { return } - if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil { - return - } return } func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) return } func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) return } func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) return } func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) return } |