aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-10-25 08:13:18 +0000
committerGitHub <noreply@github.com>2023-10-25 10:13:18 +0200
commita0375d41fbbdabd98df743d2e7fa77b4d0c44d4b (patch)
treee4fe57b63372abdb73fb42d688a2648803ea4ea5 /userapi
parente02a7948d8556398ceb345a241c175b5ca1d011f (diff)
Add simple test for one time keys (#3239)
Diffstat (limited to 'userapi')
-rw-r--r--userapi/storage/storage_test.go51
1 files changed, 51 insertions, 0 deletions
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index a46ee9eb..5a789dfd 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -1,6 +1,7 @@
package storage_test
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
@@ -758,3 +759,53 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
})
}
+
+func TestOneTimeKeys(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ userID := "@alice:localhost"
+ deviceID := "alice_device"
+ otk := api.OneTimeKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
+ }
+
+ // Add a one time key to the DB
+ _, err := db.StoreOneTimeKeys(ctx, otk)
+ MustNotError(t, err)
+
+ // Check the count of one time keys is correct
+ count, err := db.OneTimeKeysCount(ctx, userID, deviceID)
+ MustNotError(t, err)
+ if count.KeyCount["curve25519"] != 1 {
+ t.Fatalf("Expected 1 key, got %d", count.KeyCount["curve25519"])
+ }
+
+ // Check the actual key contents are correct
+ keysJSON, err := db.ExistingOneTimeKeys(ctx, userID, deviceID, []string{"curve25519:KEY1"})
+ MustNotError(t, err)
+ keyJSON, err := keysJSON["curve25519:KEY1"].MarshalJSON()
+ MustNotError(t, err)
+ if !bytes.Equal(keyJSON, []byte(`{"key":"v1"}`)) {
+ t.Fatalf("Existing keys do not match expected. Got %v", keysJSON["curve25519:KEY1"])
+ }
+
+ // Claim a one time key from the database. This should remove it from the database.
+ claimedKeys, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
+ MustNotError(t, err)
+
+ // Check the claimed key contents are correct
+ if !reflect.DeepEqual(claimedKeys[0], otk) {
+ t.Fatalf("Expected to claim stored key %v. Got %v", otk, claimedKeys[0])
+ }
+
+ // Check the count of one time keys is now zero
+ count, err = db.OneTimeKeysCount(ctx, userID, deviceID)
+ MustNotError(t, err)
+ if count.KeyCount["curve25519"] != 0 {
+ t.Fatalf("Expected 0 keys, got %d", count.KeyCount["curve25519"])
+ }
+ })
+}