diff options
author | Kegsay <kegan@matrix.org> | 2020-08-03 17:07:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-03 17:07:06 +0100 |
commit | fb56bbf0b7d4b21da3f55b066e71d24bf4599887 (patch) | |
tree | 707e8051385e92c176f2c99cb3d03607724b8309 /keyserver | |
parent | ffcb6d2ea199cfa985e72ffbdcb884fb9bc9f54d (diff) |
Generate stream IDs for locally uploaded device keys (#1236)
* Breaking: add stream_id to keyserver_device_keys table
* Add tests for stream ID generation
* Fix whitelist
Diffstat (limited to 'keyserver')
-rw-r--r-- | keyserver/api/api.go | 17 | ||||
-rw-r--r-- | keyserver/internal/internal.go | 36 | ||||
-rw-r--r-- | keyserver/producers/keychange.go | 2 | ||||
-rw-r--r-- | keyserver/storage/interface.go | 9 | ||||
-rw-r--r-- | keyserver/storage/postgres/device_keys_table.go | 82 | ||||
-rw-r--r-- | keyserver/storage/shared/storage.go | 29 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/device_keys_table.go | 78 | ||||
-rw-r--r-- | keyserver/storage/storage_test.go | 82 | ||||
-rw-r--r-- | keyserver/storage/tables/interface.go | 7 |
9 files changed, 261 insertions, 81 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go index eb2f9e24..080d0e5f 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -43,6 +43,13 @@ func (k *KeyError) Error() string { return k.Err } +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + DeviceKeys + // A monotonically increasing number which represents device changes for this user. + StreamID int +} + // DeviceKeys represents a set of device keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type DeviceKeys struct { @@ -50,10 +57,20 @@ type DeviceKeys struct { UserID string // The device ID of this device DeviceID string + // The device display name + DisplayName string // The raw device key JSON KeyJSON []byte } +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { + return DeviceMessage{ + DeviceKeys: *k, + StreamID: streamID, + } +} + // OneTimeKeys represents a set of one-time keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type OneTimeKeys struct { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 3c8dff84..9027cbf4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) - a.uploadDeviceKeys(ctx, req, res) + a.uploadLocalDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } @@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys( } } -func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - var keysToStore []api.DeviceKeys +func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + var keysToStore []api.DeviceMessage // assert that the user ID / device ID are not lying for each key for _, key := range req.DeviceKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if serverName != a.ThisServer { + continue // ignore remote users + } if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue // deleted keys don't need sanity checking } gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key) + keysToStore = append(keysToStore, key.WithStreamID(0)) continue } @@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU } // get existing device keys so we can check for changes - existingKeys := make([]api.DeviceKeys, len(keysToStore)) + existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { - existingKeys[i] = api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, + existingKeys[i] = api.DeviceMessage{ + DeviceKeys: api.DeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, + }, } } if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { @@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU return } // store the device keys and emit changes - if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { + err := a.DB.StoreDeviceKeys(ctx, keysToStore) + if err != nil { res.Error = &api.KeyError{ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), } return } - err := a.emitDeviceKeyChanges(existingKeys, keysToStore) + err = a.emitDeviceKeyChanges(existingKeys, keysToStore) if err != nil { util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) } @@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform } -func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { +func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error { // find keys in new that are not in existing - var keysAdded []api.DeviceKeys + var keysAdded []api.DeviceMessage for _, newKey := range new { exists := false for _, existingKey := range existing { diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 6035b67b..99629b42 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 { } // ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { for _, key := range keys { var m sarama.ProducerMessage diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 0e0158e5..11284d86 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -32,17 +32,18 @@ type Database interface { // OneTimeKeysCount returns a count of all OTKs for this device. OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error + StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d915246c..e1b4e947 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err } - return nil - }) + } + return nil } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) if err != nil { return nil, err @@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID for _, d := range deviceIDs { deviceIDMap[d] = true } - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 44cb0cc2..e78ee943 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) } -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) +func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = int(streamID) + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) } -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 69fe7a6e..9f70885a 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" ) @@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( device_id TEXT NOT NULL, ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + + " VALUES ($1, $2, $3, $4, $5)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4" + " DO UPDATE SET key_json = $4, stream_id = $5" const selectDeviceKeysSQL = "" + - "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt } func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { return nil, err } + if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { + return nil, err + } return s, nil } -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) { deviceIDMap := make(map[string]bool) for _, d := range deviceIDs { deviceIDMap[d] = true @@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceKeys + var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceKeys + var dk api.DeviceMessage dk.UserID = userID var keyJSON string - if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { + var streamID int + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) + dk.StreamID = streamID // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID return result, rows.Err() } -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { for i, key := range keys { var keyJSONStr string - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + var streamID int + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID } return nil } -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { - now := time.Now().Unix() - return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { - for _, key := range keys { - _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), - ) - if err != nil { - return err - } +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) { + // nullable if there are no results + var nullStream sql.NullInt32 + err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int32 + } + return +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ) + if err != nil { + return err } - return nil - }) + } + return nil } diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 66f6930f..b3e45e6c 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/Shopify/sarama" + "github.com/matrix-org/dendrite/keyserver/api" ) var ctx = context.Background() @@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) } } + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + db, err := NewDatabase("file::memory:", nil) + if err != nil { + t.Fatalf("Failed to NewDatabase: %s", err) + } + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}) + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index c6e43be4..65da3310 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -32,9 +32,10 @@ type OneTimeKeys interface { } type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error - InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } type KeyChanges interface { |