diff options
Diffstat (limited to 'userapi/storage/devices/sqlite3/devices_table.go')
-rw-r--r-- | userapi/storage/devices/sqlite3/devices_table.go | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ec52c64b..efe6f927 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -20,6 +20,7 @@ import ( "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" + type devicesStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -79,6 +83,7 @@ type devicesStatements struct { selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt @@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } @@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart( return devices, nil } + +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) + iDeviceIDs := make([]interface{}, len(deviceIDs)) + for i := range deviceIDs { + iDeviceIDs[i] = deviceIDs[i] + } + + rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} |