diff options
Diffstat (limited to 'userapi/storage/devices/sqlite3/devices_table.go')
-rw-r--r-- | userapi/storage/devices/sqlite3/devices_table.go | 66 |
1 files changed, 41 insertions, 25 deletions
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 07ea5dca..ec52c64b 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -74,6 +74,7 @@ const deleteDevicesSQL = "" + type devicesStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -87,6 +88,7 @@ type devicesStatements struct { func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(devicesSchema) if err != nil { return @@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice( ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { - return nil, err - } - sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + return err + } + return nil + }) + if err != nil { return nil, err } return &api.Device{ @@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err + }) } func (s *devicesStatements) deleteDevices( @@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices( if err != nil { return err } - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v - } - params = append(params, params...) - _, err = stmt.ExecContext(ctx, params...) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + params = append(params, params...) + _, err = stmt.ExecContext(ctx, params...) + return err + }) } func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err + }) } func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err + }) } func (s *devicesStatements) selectDeviceByToken( |