aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/postgres/devices_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/postgres/devices_table.go')
-rw-r--r--userapi/storage/postgres/devices_table.go78
1 files changed, 29 insertions, 49 deletions
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 64cc0b71..7bc5dc69 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -111,53 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *devicesStatements) execSchema(db *sql.DB) error {
- _, err := db.Exec(devicesSchema)
- return err
-}
-
-func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
- if err = s.execSchema(db); err != nil {
- return
- }
- if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
- return
- }
- if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
- return
- }
- if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
- return
- }
- if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
- return
+func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
+ s := &devicesStatements{
+ serverName: serverName,
}
- if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
- return
- }
- if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
- return
- }
- if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
- return
- }
- if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
- return
- }
- if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
- return
- }
- if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
- return
+ _, err := db.Exec(devicesSchema)
+ if err != nil {
+ return nil, err
}
- s.serverName = server
- return
+ return s, sqlutil.StatementList{
+ {&s.insertDeviceStmt, insertDeviceSQL},
+ {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
+ {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
+ {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
+ {&s.updateDeviceNameStmt, updateDeviceNameSQL},
+ {&s.deleteDeviceStmt, deleteDeviceSQL},
+ {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
+ {&s.deleteDevicesStmt, deleteDevicesSQL},
+ {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
+ {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
+ }.Prepare(db)
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
-func (s *devicesStatements) insertDevice(
+func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
@@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice(
}
// deleteDevice removes a single device by id and user localpart.
-func (s *devicesStatements) deleteDevice(
+func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
-func (s *devicesStatements) deleteDevices(
+func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
-func (s *devicesStatements) deleteDevicesByLocalpart(
+func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err
}
-func (s *devicesStatements) updateDeviceName(
+func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err
}
-func (s *devicesStatements) selectDeviceByToken(
+func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
@@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
-func (s *devicesStatements) selectDeviceByID(
+func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
@@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err
}
-func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
@@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err()
}
-func (s *devicesStatements) selectDevicesByLocalpart(
+func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
@@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err()
}
-func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
+func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)