aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorSam <me@samcday.com>2020-09-24 12:10:14 +0200
committerGitHub <noreply@github.com>2020-09-24 11:10:14 +0100
commita6700331cee70a3ca04c834efdc68cc2ef63947d (patch)
tree10420639a2a4685b2562ff1d106d1ae9a11712d4 /keyserver
parent60524f4b994ba070fff74fa47599e0fcc7856fc0 (diff)
Update all usages of tx.Stmt to sqlutil.TxStmt (#1423)
* Replace all usages of txn.Stmt with sqlutil.TxStmt Signed-off-by: Sam Day <me@samcday.com> * Fix sign off link in PR template. Signed-off-by: Sam Day <me@samcday.com> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/storage/postgres/device_keys_table.go7
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go9
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go6
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go9
4 files changed, 17 insertions, 14 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index e5bec8f6..95064fc8 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -21,6 +21,7 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -125,7 +126,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
- err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
@@ -151,7 +152,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
@@ -162,7 +163,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
- _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
index a299861d..6e32838b 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/keyserver/storage/postgres/one_time_keys_table.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -151,14 +152,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
- _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
- rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return nil, err
}
@@ -180,14 +181,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
- err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
- _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index e7ff9976..9112fc6e 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -97,7 +97,7 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
- _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
@@ -156,7 +156,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
- err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
@@ -188,7 +188,7 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index 1b6a74d6..185b8861 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -153,14 +154,14 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
- _, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
- rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return nil, err
}
@@ -182,14 +183,14 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
- err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
- _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
if err != nil {
return nil, err
}