diff options
Diffstat (limited to 'userapi/storage/postgres/devices_table.go')
-rw-r--r-- | userapi/storage/postgres/devices_table.go | 78 |
1 files changed, 29 insertions, 49 deletions
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 64cc0b71..7bc5dc69 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -111,53 +112,32 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) execSchema(db *sql.DB) error { - _, err := db.Exec(devicesSchema) - return err -} - -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if err = s.execSchema(db); err != nil { - return - } - if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { - return - } - if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { - return - } - if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { - return - } - if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { - return +func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) { + s := &devicesStatements{ + serverName: serverName, } - if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { - return - } - if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { - return - } - if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { - return - } - if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { - return - } - if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { - return - } - if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil { - return + _, err := db.Exec(devicesSchema) + if err != nil { + return nil, err } - s.serverName = server - return + return s, sqlutil.StatementList{ + {&s.insertDeviceStmt, insertDeviceSQL}, + {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, + {&s.selectDeviceByIDStmt, selectDeviceByIDSQL}, + {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL}, + {&s.updateDeviceNameStmt, updateDeviceNameSQL}, + {&s.deleteDeviceStmt, deleteDeviceSQL}, + {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL}, + {&s.deleteDevicesStmt, deleteDevicesSQL}, + {&s.selectDevicesByIDStmt, selectDevicesByIDSQL}, + {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen}, + }.Prepare(db) } // insertDevice creates a new device. Returns an error if any device with the same access token already exists. // Returns an error if the user already has a device with the given device ID. // Returns the device on success. -func (s *devicesStatements) insertDevice( +func (s *devicesStatements) InsertDevice( ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { @@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice( } // deleteDevice removes a single device by id and user localpart. -func (s *devicesStatements) deleteDevice( +func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) @@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice( // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. -func (s *devicesStatements) deleteDevices( +func (s *devicesStatements) DeleteDevices( ctx context.Context, txn *sql.Tx, localpart string, devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) @@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. -func (s *devicesStatements) deleteDevicesByLocalpart( +func (s *devicesStatements) DeleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) @@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart( return err } -func (s *devicesStatements) updateDeviceName( +func (s *devicesStatements) UpdateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) @@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName( return err } -func (s *devicesStatements) selectDeviceByToken( +func (s *devicesStatements) SelectDeviceByToken( ctx context.Context, accessToken string, ) (*api.Device, error) { var dev api.Device @@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID -func (s *devicesStatements) selectDeviceByID( +func (s *devicesStatements) SelectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device @@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID( return &dev, err } -func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { +func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) if err != nil { return nil, err @@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s return devices, rows.Err() } -func (s *devicesStatements) selectDevicesByLocalpart( +func (s *devicesStatements) SelectDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} @@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, rows.Err() } -func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) |