aboutsummaryrefslogtreecommitdiff
path: root/federationsender/storage/postgres/queue_pdus_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'federationsender/storage/postgres/queue_pdus_table.go')
-rw-r--r--federationsender/storage/postgres/queue_pdus_table.go73
1 files changed, 39 insertions, 34 deletions
diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go
index dab6003e..95a3b9ee 100644
--- a/federationsender/storage/postgres/queue_pdus_table.go
+++ b/federationsender/storage/postgres/queue_pdus_table.go
@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
+ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
@@ -41,10 +42,10 @@ const insertQueuePDUSQL = "" +
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
-const deleteQueueTransactionPDUsSQL = "" +
- "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND transaction_id = $2"
+const deleteQueuePDUSQL = "" +
+ "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)"
-const selectQueueNextTransactionIDSQL = "" +
+const selectQueuePDUNextTransactionIDSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
@@ -55,7 +56,7 @@ const selectQueuePDUsByTransactionSQL = "" +
" WHERE server_name = $1 AND transaction_id = $2" +
" LIMIT $3"
-const selectQueueReferenceJSONCountSQL = "" +
+const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
@@ -63,49 +64,53 @@ const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
-const selectQueueServerNamesSQL = "" +
+const selectQueuePDUServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
type queuePDUsStatements struct {
- insertQueuePDUStmt *sql.Stmt
- deleteQueueTransactionPDUsStmt *sql.Stmt
- selectQueueNextTransactionIDStmt *sql.Stmt
- selectQueuePDUsByTransactionStmt *sql.Stmt
- selectQueueReferenceJSONCountStmt *sql.Stmt
- selectQueuePDUsCountStmt *sql.Stmt
- selectQueueServerNamesStmt *sql.Stmt
+ db *sql.DB
+ insertQueuePDUStmt *sql.Stmt
+ deleteQueuePDUsStmt *sql.Stmt
+ selectQueuePDUNextTransactionIDStmt *sql.Stmt
+ selectQueuePDUsByTransactionStmt *sql.Stmt
+ selectQueuePDUReferenceJSONCountStmt *sql.Stmt
+ selectQueuePDUsCountStmt *sql.Stmt
+ selectQueuePDUServerNamesStmt *sql.Stmt
}
-func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(queuePDUsSchema)
+func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
+ s = &queuePDUsStatements{
+ db: db,
+ }
+ _, err = s.db.Exec(queuePDUsSchema)
if err != nil {
return
}
- if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil {
+ if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil {
return
}
- if s.deleteQueueTransactionPDUsStmt, err = db.Prepare(deleteQueueTransactionPDUsSQL); err != nil {
+ if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil {
return
}
- if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
+ if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil {
return
}
- if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil {
+ if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil {
return
}
- if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil {
+ if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
return
}
- if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
+ if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
- if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
+ if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
return
}
return
}
-func (s *queuePDUsStatements) insertQueuePDU(
+func (s *queuePDUsStatements) InsertQueuePDU(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
@@ -122,21 +127,21 @@ func (s *queuePDUsStatements) insertQueuePDU(
return err
}
-func (s *queuePDUsStatements) deleteQueueTransaction(
+func (s *queuePDUsStatements) DeleteQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
- transactionID gomatrixserverlib.TransactionID,
+ jsonNIDs []int64,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt)
- _, err := stmt.ExecContext(ctx, serverName, transactionID)
+ stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt)
+ _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
-func (s *queuePDUsStatements) selectQueueNextTransactionID(
+func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID
- stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
+ stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
if err == sql.ErrNoRows {
return "", nil
@@ -144,11 +149,11 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID(
return transactionID, err
}
-func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
+func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
- stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
+ stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
@@ -159,7 +164,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
return count, err
}
-func (s *queuePDUsStatements) selectQueuePDUCount(
+func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
@@ -174,7 +179,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount(
return count, err
}
-func (s *queuePDUsStatements) selectQueuePDUs(
+func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
@@ -198,10 +203,10 @@ func (s *queuePDUsStatements) selectQueuePDUs(
return result, rows.Err()
}
-func (s *queuePDUsStatements) selectQueueServerNames(
+func (s *queuePDUsStatements) SelectQueuePDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
- stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
+ stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err