diff options
Diffstat (limited to 'keyserver/storage/sqlite3/device_keys_table.go')
-rw-r--r-- | keyserver/storage/sqlite3/device_keys_table.go | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index abe6636a..a4d71fe1 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" + const countStreamIDsForUserSQL = "" + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + type deviceKeysStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -65,6 +68,7 @@ type deviceKeysStatements struct { selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { + return nil, err + } return s, nil } +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { |