aboutsummaryrefslogtreecommitdiff
path: root/serverkeyapi/storage/sqlite3/server_key_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'serverkeyapi/storage/sqlite3/server_key_table.go')
-rw-r--r--serverkeyapi/storage/sqlite3/server_key_table.go25
1 files changed, 15 insertions, 10 deletions
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go
index 4f03dccb..423292a5 100644
--- a/serverkeyapi/storage/sqlite3/server_key_table.go
+++ b/serverkeyapi/storage/sqlite3/server_key_table.go
@@ -63,12 +63,14 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
@@ -136,16 +138,19 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
- _, err := s.upsertServerKeysStmt.ExecContext(
- ctx,
- string(request.ServerName),
- string(request.KeyID),
- nameAndKeyID(request),
- key.ValidUntilTS,
- key.ExpiredTS,
- key.Key.Encode(),
- )
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ string(request.ServerName),
+ string(request.KeyID),
+ nameAndKeyID(request),
+ key.ValidUntilTS,
+ key.ExpiredTS,
+ key.Key.Encode(),
+ )
+ return err
+ })
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {