diff options
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/storage/accounts/postgres/storage.go | 25 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/account_data_table.go | 6 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/accounts_table.go | 6 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/profile_table.go | 6 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/storage.go | 33 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/threepid_table.go | 6 | ||||
-rw-r--r-- | userapi/storage/devices/sqlite3/devices_table.go | 68 | ||||
-rw-r--r-- | userapi/storage/devices/sqlite3/storage.go | 18 |
8 files changed, 76 insertions, 92 deletions
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 9653c019..b36264dd 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -34,7 +34,8 @@ import ( // Database represents an account database type Database struct { - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "account"); err != nil { + d := &Database{ + serverName: serverName, + db: db, + writer: sqlutil.NewDummyWriter(), + } + if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { return nil, err } - a := accountsStatements{} - if err = a.prepare(db, serverName); err != nil { + if err = d.accounts.prepare(db, serverName); err != nil { return nil, err } - p := profilesStatements{} - if err = p.prepare(db); err != nil { + if err = d.profiles.prepare(db); err != nil { return nil, err } - ac := accountDataStatements{} - if err = ac.prepare(db); err != nil { + if err = d.accountDatas.prepare(db); err != nil { return nil, err } - t := threepidStatements{} - if err = t.prepare(db); err != nil { + if err = d.threepids.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, ac, t, serverName}, nil + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 9b40e657..aee8db6e 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountDataSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 586bcab9..83b90668 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -67,9 +67,9 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountsSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index cd35d298..1ec45e03 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt @@ -61,9 +61,9 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { +func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(profilesSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4d2c5e51..4f45f754 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -33,7 +33,9 @@ import ( // Database represents an account database type Database struct { - db *sql.DB + db *sql.DB + writer sqlutil.Writer + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } + d := &Database{ + serverName: serverName, + db: db, + writer: sqlutil.NewExclusiveWriter(), + } partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "account"); err != nil { + if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err } - a := accountsStatements{} - if err = a.prepare(db, serverName); err != nil { + if err = d.accounts.prepare(db, d.writer, serverName); err != nil { return nil, err } - p := profilesStatements{} - if err = p.prepare(db); err != nil { + if err = d.profiles.prepare(db, d.writer); err != nil { return nil, err } - ac := accountDataStatements{} - if err = ac.prepare(db); err != nil { + if err = d.accountDatas.prepare(db, d.writer); err != nil { return nil, err } - t := threepidStatements{} - if err = t.prepare(db); err != nil { + if err = d.threepids.prepare(db, d.writer); err != nil { return nil, err } - return &Database{ - db: db, - PartitionOffsetStatements: partitions, - accounts: a, - profiles: p, - accountDatas: ac, - threepids: t, - serverName: serverName, - }, nil + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 3000d7c4..230978fe 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { +func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(threepidSchema) if err != nil { return diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 962e63b0..c93e8b77 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -91,9 +91,9 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(devicesSchema) if err != nil { return @@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice( ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { - return err - } - sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { - return err - } - return nil - }) - if err != nil { + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } return &api.Device{ @@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err } func (s *devicesStatements) deleteDevices( @@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices( if err != nil { return err } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v - } - _, err = stmt.ExecContext(ctx, params...) - return err - }) + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err } func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err } func (s *devicesStatements) selectDeviceByToken( diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 1f2b59f3..4f426c6e 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -34,6 +34,7 @@ var deviceIDByteLength = 6 // Database represents a device database. type Database struct { db *sql.DB + writer sqlutil.Writer devices devicesStatements } @@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } + writer := sqlutil.NewExclusiveWriter() d := devicesStatements{} - if err = d.prepare(db, serverName); err != nil { + if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, d}, nil + return &Database{db, writer, d}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -88,7 +90,7 @@ func (d *Database) CreateDevice( displayName *string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -108,7 +110,7 @@ func (d *Database) CreateDevice( return } - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -138,7 +140,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -150,7 +152,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -165,7 +167,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -179,7 +181,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } |