aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-08-05 11:01:37 +0100
committerGitHub <noreply@github.com>2020-08-05 11:01:37 +0100
commit15dc1f4d0361da736339653ca8e6ba26ed103792 (patch)
tree1a52345f1ca1a4a9240cf7993f533d78bbc8c664 /keyserver/storage
parent22f028e141297bcd8b1230573e55d5790e0d67a4 (diff)
Use TransactionWriter in SQLite keyserver (#1239)
* Use TransactionWriter in SQLite keyserver * Fix keyserver storage tests
Diffstat (limited to 'keyserver/storage')
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go25
-rw-r--r--keyserver/storage/sqlite3/key_changes_table.go11
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go23
-rw-r--r--keyserver/storage/storage_test.go44
4 files changed, 65 insertions, 38 deletions
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index 9f70885a..900d1238 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -54,6 +55,7 @@ const selectMaxStreamForUserSQL = "" +
type deviceKeysStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
@@ -62,7 +64,8 @@ type deviceKeysStatements struct {
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
- db: db,
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(deviceKeysSchema)
if err != nil {
@@ -141,14 +144,16 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
- for _, key := range keys {
- now := time.Now().Unix()
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
- )
- if err != nil {
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ )
+ if err != nil {
+ return err
+ }
}
- }
- return nil
+ return nil
+ })
}
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
index 32721eae..02b9d193 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -21,6 +21,7 @@ import (
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -51,13 +52,15 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
- db: db,
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
@@ -73,8 +76,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
- _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
+ return err
+ })
}
func (s *keyChangesStatements) SelectKeyChanges(
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index b35407cd..f910479f 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -60,6 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
@@ -69,7 +70,8 @@ type oneTimeKeysStatements struct {
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{
- db: db,
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(oneTimeKeysSchema)
if err != nil {
@@ -150,7 +152,7 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
- return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
+ return counts, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := txn.Stmt(s.upsertKeysStmt).ExecContext(
@@ -183,14 +185,17 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
- err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil
+ }
+ return err
}
- return nil, err
- }
- _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return err
+ })
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index b3e45e6c..949d9dd6 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -2,6 +2,10 @@ package storage
import (
"context"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
"reflect"
"testing"
@@ -11,6 +15,21 @@ import (
var ctx = context.Background()
+func MustCreateDatabase(t *testing.T) (Database, func()) {
+ tmpfile, err := ioutil.TempFile("", "keyserver_storage_test")
+ if err != nil {
+ log.Fatal(err)
+ }
+ t.Logf("Database %s", tmpfile.Name())
+ db, err := NewDatabase(fmt.Sprintf("file://%s", tmpfile.Name()), nil)
+ if err != nil {
+ t.Fatalf("Failed to NewDatabase: %s", err)
+ }
+ return db, func() {
+ os.Remove(tmpfile.Name())
+ }
+}
+
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
@@ -20,10 +39,8 @@ func MustNotError(t *testing.T, err error) {
}
func TestKeyChanges(t *testing.T) {
- db, err := NewDatabase("file::memory:", nil)
- if err != nil {
- t.Fatalf("Failed to NewDatabase: %s", err)
- }
+ db, clean := MustCreateDatabase(t)
+ defer clean()
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
@@ -40,10 +57,8 @@ func TestKeyChanges(t *testing.T) {
}
func TestKeyChangesNoDupes(t *testing.T) {
- db, err := NewDatabase("file::memory:", nil)
- if err != nil {
- t.Fatalf("Failed to NewDatabase: %s", err)
- }
+ db, clean := MustCreateDatabase(t)
+ defer clean()
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
@@ -60,10 +75,8 @@ func TestKeyChangesNoDupes(t *testing.T) {
}
func TestKeyChangesUpperLimit(t *testing.T) {
- db, err := NewDatabase("file::memory:", nil)
- if err != nil {
- t.Fatalf("Failed to NewDatabase: %s", err)
- }
+ db, clean := MustCreateDatabase(t)
+ defer clean()
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
@@ -82,10 +95,9 @@ func TestKeyChangesUpperLimit(t *testing.T) {
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
- db, err := NewDatabase("file::memory:", nil)
- if err != nil {
- t.Fatalf("Failed to NewDatabase: %s", err)
- }
+ var err error
+ db, clean := MustCreateDatabase(t)
+ defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{