aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/devices/sqlite3/devices_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/devices/sqlite3/devices_table.go')
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go66
1 files changed, 41 insertions, 25 deletions
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index 07ea5dca..ec52c64b 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -74,6 +74,7 @@ const deleteDevicesSQL = "" +
type devicesStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@@ -87,6 +88,7 @@ type devicesStatements struct {
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(devicesSchema)
if err != nil {
return
@@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
- countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
- insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
- if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
- return nil, err
- }
- sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
+ insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
+ if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
+ return err
+ }
+ sessionID++
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
return nil, err
}
return &api.Device{
@@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
+ _, err := stmt.ExecContext(ctx, id, localpart)
+ return err
+ })
}
func (s *devicesStatements) deleteDevices(
@@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
- stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
- params[0] = localpart
- for i, v := range devices {
- params[i+1] = v
- }
- params = append(params, params...)
- _, err = stmt.ExecContext(ctx, params...)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, prep)
+ params := make([]interface{}, len(devices)+1)
+ params[0] = localpart
+ for i, v := range devices {
+ params[i+1] = v
+ }
+ params = append(params, params...)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+ })
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
+ _, err := stmt.ExecContext(ctx, localpart)
+ return err
+ })
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
- stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ return err
+ })
}
func (s *devicesStatements) selectDeviceByToken(