diff options
Diffstat (limited to 'federationsender/storage/postgres/queue_pdus_table.go')
-rw-r--r-- | federationsender/storage/postgres/queue_pdus_table.go | 73 |
1 files changed, 39 insertions, 34 deletions
diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index dab6003e..95a3b9ee 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" @@ -41,10 +42,10 @@ const insertQueuePDUSQL = "" + "INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" + " VALUES ($1, $2, $3)" -const deleteQueueTransactionPDUsSQL = "" + - "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND transaction_id = $2" +const deleteQueuePDUSQL = "" + + "DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)" -const selectQueueNextTransactionIDSQL = "" + +const selectQueuePDUNextTransactionIDSQL = "" + "SELECT transaction_id FROM federationsender_queue_pdus" + " WHERE server_name = $1" + " ORDER BY transaction_id ASC" + @@ -55,7 +56,7 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" -const selectQueueReferenceJSONCountSQL = "" + +const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" @@ -63,49 +64,53 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" -const selectQueueServerNamesSQL = "" + +const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" type queuePDUsStatements struct { - insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsByTransactionStmt *sql.Stmt - selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt - selectQueueServerNamesStmt *sql.Stmt + db *sql.DB + insertQueuePDUStmt *sql.Stmt + deleteQueuePDUsStmt *sql.Stmt + selectQueuePDUNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt + selectQueuePDUReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt + selectQueuePDUServerNamesStmt *sql.Stmt } -func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(queuePDUsSchema) +func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { + s = &queuePDUsStatements{ + db: db, + } + _, err = s.db.Exec(queuePDUsSchema) if err != nil { return } - if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil { + if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil { return } - if s.deleteQueueTransactionPDUsStmt, err = db.Prepare(deleteQueueTransactionPDUsSQL); err != nil { + if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil { return } - if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil { + if s.selectQueuePDUNextTransactionIDStmt, err = s.db.Prepare(selectQueuePDUNextTransactionIDSQL); err != nil { return } - if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { + if s.selectQueuePDUsByTransactionStmt, err = s.db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } - if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { + if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { return } - if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } return } -func (s *queuePDUsStatements) insertQueuePDU( +func (s *queuePDUsStatements) InsertQueuePDU( ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, @@ -122,21 +127,21 @@ func (s *queuePDUsStatements) insertQueuePDU( return err } -func (s *queuePDUsStatements) deleteQueueTransaction( +func (s *queuePDUsStatements) DeleteQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, - transactionID gomatrixserverlib.TransactionID, + jsonNIDs []int64, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) - _, err := stmt.ExecContext(ctx, serverName, transactionID) + stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) return err } -func (s *queuePDUsStatements) selectQueueNextTransactionID( +func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (gomatrixserverlib.TransactionID, error) { var transactionID gomatrixserverlib.TransactionID - stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUNextTransactionIDStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) if err == sql.ErrNoRows { return "", nil @@ -144,11 +149,11 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } -func (s *queuePDUsStatements) selectQueueReferenceJSONCount( +func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) if err == sql.ErrNoRows { // It's acceptable for there to be no rows referencing a given @@ -159,7 +164,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUCount( +func (s *queuePDUsStatements) SelectQueuePDUCount( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (int64, error) { var count int64 @@ -174,7 +179,7 @@ func (s *queuePDUsStatements) selectQueuePDUCount( return count, err } -func (s *queuePDUsStatements) selectQueuePDUs( +func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, @@ -198,10 +203,10 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } -func (s *queuePDUsStatements) selectQueueServerNames( +func (s *queuePDUsStatements) SelectQueuePDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { - stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt) rows, err := stmt.QueryContext(ctx) if err != nil { return nil, err |