aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
Diffstat (limited to 'userapi')
-rw-r--r--userapi/storage/accounts/postgres/storage.go25
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/profile_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go33
-rw-r--r--userapi/storage/accounts/sqlite3/threepid_table.go6
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go68
-rw-r--r--userapi/storage/devices/sqlite3/storage.go18
8 files changed, 76 insertions, 92 deletions
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index 9653c019..b36264dd 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -34,7 +34,8 @@ import (
// Database represents an account database
type Database struct {
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
- partitions := sqlutil.PartitionOffsetStatements{}
- if err = partitions.Prepare(db, "account"); err != nil {
+ d := &Database{
+ serverName: serverName,
+ db: db,
+ writer: sqlutil.NewDummyWriter(),
+ }
+ if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
- a := accountsStatements{}
- if err = a.prepare(db, serverName); err != nil {
+ if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
- p := profilesStatements{}
- if err = p.prepare(db); err != nil {
+ if err = d.profiles.prepare(db); err != nil {
return nil, err
}
- ac := accountDataStatements{}
- if err = ac.prepare(db); err != nil {
+ if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
- t := threepidStatements{}
- if err = t.prepare(db); err != nil {
+ if err = d.threepids.prepare(db); err != nil {
return nil, err
}
- return &Database{db, partitions, a, p, ac, t, serverName}, nil
+ return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index 9b40e657..aee8db6e 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
-func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
+func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(accountDataSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 586bcab9..83b90668 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(accountsSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go
index cd35d298..1ec45e03 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/accounts/sqlite3/profile_table.go
@@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
@@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
-func (s *profilesStatements) prepare(db *sql.DB) (err error) {
+func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(profilesSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 4d2c5e51..4f45f754 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -33,7 +33,9 @@ import (
// Database represents an account database
type Database struct {
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
+
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
+ d := &Database{
+ serverName: serverName,
+ db: db,
+ writer: sqlutil.NewExclusiveWriter(),
+ }
partitions := sqlutil.PartitionOffsetStatements{}
- if err = partitions.Prepare(db, "account"); err != nil {
+ if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
- a := accountsStatements{}
- if err = a.prepare(db, serverName); err != nil {
+ if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
- p := profilesStatements{}
- if err = p.prepare(db); err != nil {
+ if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err
}
- ac := accountDataStatements{}
- if err = ac.prepare(db); err != nil {
+ if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err
}
- t := threepidStatements{}
- if err = t.prepare(db); err != nil {
+ if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err
}
- return &Database{
- db: db,
- PartitionOffsetStatements: partitions,
- accounts: a,
- profiles: p,
- accountDatas: ac,
- threepids: t,
- serverName: serverName,
- }, nil
+ return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go
index 3000d7c4..230978fe 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/accounts/sqlite3/threepid_table.go
@@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" +
type threepidStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
-func (s *threepidStatements) prepare(db *sql.DB) (err error) {
+func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(threepidSchema)
if err != nil {
return
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index 962e63b0..c93e8b77 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
type devicesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(devicesSchema)
if err != nil {
return
@@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
- insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
- if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
- return err
- }
- sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
- return err
- }
- return nil
- })
- if err != nil {
+ countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
+ insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
+ if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
+ return nil, err
+ }
+ sessionID++
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &api.Device{
@@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
+ _, err := stmt.ExecContext(ctx, id, localpart)
+ return err
}
func (s *devicesStatements) deleteDevices(
@@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
- params[0] = localpart
- for i, v := range devices {
- params[i+1] = v
- }
- _, err = stmt.ExecContext(ctx, params...)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, prep)
+ params := make([]interface{}, len(devices)+1)
+ params[0] = localpart
+ for i, v := range devices {
+ params[i+1] = v
+ }
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
+ _, err := stmt.ExecContext(ctx, localpart)
+ return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ return err
}
func (s *devicesStatements) selectDeviceByToken(
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
index 1f2b59f3..4f426c6e 100644
--- a/userapi/storage/devices/sqlite3/storage.go
+++ b/userapi/storage/devices/sqlite3/storage.go
@@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
+ writer sqlutil.Writer
devices devicesStatements
}
@@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
+ writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
- if err = d.prepare(db, serverName); err != nil {
+ if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
- return &Database{db, d}, nil
+ return &Database{db, writer, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
@@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return
}
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
@@ -138,7 +140,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
@@ -150,7 +152,7 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
@@ -165,7 +167,7 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
@@ -179,7 +181,7 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}