aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage/postgres/device_keys_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/storage/postgres/device_keys_table.go')
-rw-r--r--keyserver/storage/postgres/device_keys_table.go33
1 files changed, 23 insertions, 10 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index 5ae0da96..628301cf 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- countStreamIDsForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
+ return nil, err
+ }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
- rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}