diff options
author | Sam <me@samcday.com> | 2020-09-24 12:10:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-24 11:10:14 +0100 |
commit | a6700331cee70a3ca04c834efdc68cc2ef63947d (patch) | |
tree | 10420639a2a4685b2562ff1d106d1ae9a11712d4 /keyserver | |
parent | 60524f4b994ba070fff74fa47599e0fcc7856fc0 (diff) |
Update all usages of tx.Stmt to sqlutil.TxStmt (#1423)
* Replace all usages of txn.Stmt with sqlutil.TxStmt
Signed-off-by: Sam Day <me@samcday.com>
* Fix sign off link in PR template.
Signed-off-by: Sam Day <me@samcday.com>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'keyserver')
4 files changed, 17 insertions, 14 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index e5bec8f6..95064fc8 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { // nullable if there are no results var nullStream sql.NullInt32 - err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } @@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { @@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx } func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err } diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index a299861d..6e32838b 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql. } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) if err != nil { return nil, err } @@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index e7ff9976..9112fc6e 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { } func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) return err } @@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { // nullable if there are no results var nullStream sql.NullInt32 - err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) if err == sql.ErrNoRows { err = nil } @@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index 1b6a74d6..185b8861 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys( } for keyIDWithAlgo, keyJSON := range keys.KeyJSON { algo, keyID := keys.Split(keyIDWithAlgo) - _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), ) if err != nil { return nil, err } } - rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) if err != nil { return nil, err } @@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( ) (map[string]json.RawMessage, error) { var keyID string var keyJSON string - err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } - _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) if err != nil { return nil, err } |