aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage/sqlite3/device_keys_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/storage/sqlite3/device_keys_table.go')
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go12
1 files changed, 12 insertions, 0 deletions
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index abe6636a..a4d71fe1 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
type deviceKeysStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@@ -65,6 +68,7 @@ type deviceKeysStatements struct {
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
+ if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {