aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/devices/sqlite3/devices_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/devices/sqlite3/devices_table.go')
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go43
1 files changed, 42 insertions, 1 deletions
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index ec52c64b..efe6f927 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -20,6 +20,7 @@ import (
"strings"
"time"
+ "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
+const selectDevicesByIDSQL = "" +
+ "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
+
type devicesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@@ -79,6 +83,7 @@ type devicesStatements struct {
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
+ selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
@@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
+ if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
+ return
+ }
s.serverName = server
return
}
@@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
+ var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
+ err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ if displayName.Valid {
+ dev.DisplayName = displayName.String
+ }
}
return &dev, err
}
@@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil
}
+
+func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
+ iDeviceIDs := make([]interface{}, len(deviceIDs))
+ for i := range deviceIDs {
+ iDeviceIDs[i] = deviceIDs[i]
+ }
+
+ rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
+ var devices []api.Device
+ for rows.Next() {
+ var dev api.Device
+ var localpart string
+ var displayName sql.NullString
+ if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dev.DisplayName = displayName.String
+ }
+ dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ devices = append(devices, dev)
+ }
+ return devices, rows.Err()
+}