aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--keyserver/api/api.go17
-rw-r--r--keyserver/internal/internal.go36
-rw-r--r--keyserver/producers/keychange.go2
-rw-r--r--keyserver/storage/interface.go9
-rw-r--r--keyserver/storage/postgres/device_keys_table.go82
-rw-r--r--keyserver/storage/shared/storage.go29
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go78
-rw-r--r--keyserver/storage/storage_test.go82
-rw-r--r--keyserver/storage/tables/interface.go7
-rw-r--r--syncapi/consumers/keychange.go2
-rw-r--r--sytest-whitelist1
11 files changed, 263 insertions, 82 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index eb2f9e24..080d0e5f 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -43,6 +43,13 @@ func (k *KeyError) Error() string {
return k.Err
}
+// DeviceMessage represents the message produced into Kafka by the key server.
+type DeviceMessage struct {
+ DeviceKeys
+ // A monotonically increasing number which represents device changes for this user.
+ StreamID int
+}
+
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {
@@ -50,10 +57,20 @@ type DeviceKeys struct {
UserID string
// The device ID of this device
DeviceID string
+ // The device display name
+ DisplayName string
// The raw device key JSON
KeyJSON []byte
}
+// WithStreamID returns a copy of this device message with the given stream ID
+func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
+ return DeviceMessage{
+ DeviceKeys: *k,
+ StreamID: streamID,
+ }
+}
+
// OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct {
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 3c8dff84..9027cbf4 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
- a.uploadDeviceKeys(ctx, req, res)
+ a.uploadLocalDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res)
}
@@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
-func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
- var keysToStore []api.DeviceKeys
+func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ if serverName != a.ThisServer {
+ continue // ignore remote users
+ }
if len(key.KeyJSON) == 0 {
- keysToStore = append(keysToStore, key)
+ keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
- keysToStore = append(keysToStore, key)
+ keysToStore = append(keysToStore, key.WithStreamID(0))
continue
}
@@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
}
// get existing device keys so we can check for changes
- existingKeys := make([]api.DeviceKeys, len(keysToStore))
+ existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
- existingKeys[i] = api.DeviceKeys{
- UserID: keysToStore[i].UserID,
- DeviceID: keysToStore[i].DeviceID,
+ existingKeys[i] = api.DeviceMessage{
+ DeviceKeys: api.DeviceKeys{
+ UserID: keysToStore[i].UserID,
+ DeviceID: keysToStore[i].DeviceID,
+ },
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
@@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
return
}
// store the device keys and emit changes
- if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
+ err := a.DB.StoreDeviceKeys(ctx, keysToStore)
+ if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
}
return
}
- err := a.emitDeviceKeyChanges(existingKeys, keysToStore)
+ err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
@@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
-func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error {
+func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
// find keys in new that are not in existing
- var keysAdded []api.DeviceKeys
+ var keysAdded []api.DeviceMessage
for _, newKey := range new {
exists := false
for _, existingKey := range existing {
diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go
index 6035b67b..99629b42 100644
--- a/keyserver/producers/keychange.go
+++ b/keyserver/producers/keychange.go
@@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
}
// ProduceKeyChanges creates new change events for each key
-func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
+func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
for _, key := range keys {
var m sarama.ProducerMessage
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 0e0158e5..11284d86 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -32,17 +32,18 @@ type Database interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
- // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
- DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
+ // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
- StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
+ StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index d915246c..e1b4e947 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
+ -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
+ -- This means we do not store an unbounded append-only log of device keys, which is not actually
+ -- required in the spec because in the event of a missed update the server fetches the entire
+ -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
+ stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
+ " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
- " DO UPDATE SET key_json = $4"
+ " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
- "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
+ var streamID int
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
}
return nil
}
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- now := time.Now().Unix()
- return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
- for _, key := range keys {
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
- )
- if err != nil {
- return err
- }
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt32
+ err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int32
+ }
+ return
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ )
+ if err != nil {
+ return err
}
- return nil
- })
+ }
+ return nil
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
@@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
- var result []api.DeviceKeys
+ var result []api.DeviceMessage
for rows.Next() {
- var dk api.DeviceKeys
+ var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
- if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
+ var streamID int
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
+ dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 44cb0cc2..e78ee943 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
-func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
-func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
+func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ // work out the latest stream IDs for each user
+ userIDToStreamID := make(map[string]int)
+ for _, k := range keys {
+ userIDToStreamID[k.UserID] = 0
+ }
+ return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ for userID := range userIDToStreamID {
+ streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ userIDToStreamID[userID] = int(streamID)
+ }
+ // set the stream IDs for each key
+ for i := range keys {
+ k := keys[i]
+ userIDToStreamID[k.UserID]++ // start stream from 1
+ k.StreamID = userIDToStreamID[k.UserID]
+ keys[i] = k
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
}
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index 69fe7a6e..9f70885a 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
+ " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, device_id)" +
- " DO UPDATE SET key_json = $4"
+ " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
- "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
@@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
- var result []api.DeviceKeys
+ var result []api.DeviceMessage
for rows.Next() {
- var dk api.DeviceKeys
+ var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
- if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
+ var streamID int
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
+ dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
@@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return result, rows.Err()
}
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
+ var streamID int
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
}
return nil
}
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- now := time.Now().Unix()
- return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
- for _, key := range keys {
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
- )
- if err != nil {
- return err
- }
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt32
+ err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int32
+ }
+ return
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ )
+ if err != nil {
+ return err
}
- return nil
- })
+ }
+ return nil
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index 66f6930f..b3e45e6c 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/Shopify/sarama"
+ "github.com/matrix-org/dendrite/keyserver/api"
)
var ctx = context.Background()
@@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+ db, err := NewDatabase("file::memory:", nil)
+ if err != nil {
+ t.Fatalf("Failed to NewDatabase: %s", err)
+ }
+ alice := "@alice:TestDeviceKeysStreamIDGeneration"
+ bob := "@bob:TestDeviceKeysStreamIDGeneration"
+ msgs := []api.DeviceMessage{
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1
+ },
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: bob,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1 as this is a different user
+ },
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "another_device",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 2 as this is a 2nd device key
+ },
+ }
+ MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 1 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+ }
+ if msgs[1].StreamID != 1 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+ }
+ if msgs[2].StreamID != 2 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+ }
+
+ // updating a device sets the next stream ID for that user
+ msgs = []api.DeviceMessage{
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v2"}`),
+ },
+ // StreamID: 3
+ },
+ }
+ MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 3 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+ }
+
+ // Querying for device keys returns the latest stream IDs
+ msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
+ if err != nil {
+ t.Fatalf("DeviceKeysForUser returned error: %s", err)
+ }
+ wantStreamIDs := map[string]int{
+ "AAA": 3,
+ "another_device": 2,
+ }
+ if len(msgs) != len(wantStreamIDs) {
+ t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+ }
+ for _, m := range msgs {
+ if m.StreamID != wantStreamIDs[m.DeviceID] {
+ t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+ }
+ }
+}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index c6e43be4..65da3310 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -32,9 +32,10 @@ type OneTimeKeys interface {
}
type DeviceKeys interface {
- SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
- InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
+ SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+ InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
}
type KeyChanges interface {
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index 35978be7..e14d2223 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
defer func() {
s.updateOffset(msg)
}()
- var output api.DeviceKeys
+ var output api.DeviceMessage
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")
diff --git a/sytest-whitelist b/sytest-whitelist
index 16a71c64..a1d2e437 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates
Sync is woken up for leaves
Newly left rooms appear in the leave section of incremental sync
+Rooms can be created with an initial invite list (SYN-205)
We should see our own leave event, even if history_visibility is restricted (SYN-662)
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
Newly left rooms appear in the leave section of gapped sync