aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-08-03 17:07:06 +0100
committerGitHub <noreply@github.com>2020-08-03 17:07:06 +0100
commitfb56bbf0b7d4b21da3f55b066e71d24bf4599887 (patch)
tree707e8051385e92c176f2c99cb3d03607724b8309 /keyserver
parentffcb6d2ea199cfa985e72ffbdcb884fb9bc9f54d (diff)
Generate stream IDs for locally uploaded device keys (#1236)
* Breaking: add stream_id to keyserver_device_keys table * Add tests for stream ID generation * Fix whitelist
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/api/api.go17
-rw-r--r--keyserver/internal/internal.go36
-rw-r--r--keyserver/producers/keychange.go2
-rw-r--r--keyserver/storage/interface.go9
-rw-r--r--keyserver/storage/postgres/device_keys_table.go82
-rw-r--r--keyserver/storage/shared/storage.go29
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go78
-rw-r--r--keyserver/storage/storage_test.go82
-rw-r--r--keyserver/storage/tables/interface.go7
9 files changed, 261 insertions, 81 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index eb2f9e24..080d0e5f 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -43,6 +43,13 @@ func (k *KeyError) Error() string {
return k.Err
}
+// DeviceMessage represents the message produced into Kafka by the key server.
+type DeviceMessage struct {
+ DeviceKeys
+ // A monotonically increasing number which represents device changes for this user.
+ StreamID int
+}
+
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {
@@ -50,10 +57,20 @@ type DeviceKeys struct {
UserID string
// The device ID of this device
DeviceID string
+ // The device display name
+ DisplayName string
// The raw device key JSON
KeyJSON []byte
}
+// WithStreamID returns a copy of this device message with the given stream ID
+func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
+ return DeviceMessage{
+ DeviceKeys: *k,
+ StreamID: streamID,
+ }
+}
+
// OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct {
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 3c8dff84..9027cbf4 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
- a.uploadDeviceKeys(ctx, req, res)
+ a.uploadLocalDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res)
}
@@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
-func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
- var keysToStore []api.DeviceKeys
+func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ if serverName != a.ThisServer {
+ continue // ignore remote users
+ }
if len(key.KeyJSON) == 0 {
- keysToStore = append(keysToStore, key)
+ keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
- keysToStore = append(keysToStore, key)
+ keysToStore = append(keysToStore, key.WithStreamID(0))
continue
}
@@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
}
// get existing device keys so we can check for changes
- existingKeys := make([]api.DeviceKeys, len(keysToStore))
+ existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
- existingKeys[i] = api.DeviceKeys{
- UserID: keysToStore[i].UserID,
- DeviceID: keysToStore[i].DeviceID,
+ existingKeys[i] = api.DeviceMessage{
+ DeviceKeys: api.DeviceKeys{
+ UserID: keysToStore[i].UserID,
+ DeviceID: keysToStore[i].DeviceID,
+ },
}
}
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
@@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
return
}
// store the device keys and emit changes
- if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil {
+ err := a.DB.StoreDeviceKeys(ctx, keysToStore)
+ if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
}
return
}
- err := a.emitDeviceKeyChanges(existingKeys, keysToStore)
+ err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
@@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
-func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error {
+func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
// find keys in new that are not in existing
- var keysAdded []api.DeviceKeys
+ var keysAdded []api.DeviceMessage
for _, newKey := range new {
exists := false
for _, existingKey := range existing {
diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go
index 6035b67b..99629b42 100644
--- a/keyserver/producers/keychange.go
+++ b/keyserver/producers/keychange.go
@@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
}
// ProduceKeyChanges creates new change events for each key
-func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
+func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
for _, key := range keys {
var m sarama.ProducerMessage
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 0e0158e5..11284d86 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -32,17 +32,18 @@ type Database interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
- // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
- DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
+ // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
- StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
+ StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index d915246c..e1b4e947 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
+ -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
+ -- This means we do not store an unbounded append-only log of device keys, which is not actually
+ -- required in the spec because in the event of a missed update the server fetches the entire
+ -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
+ stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
+ " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
- " DO UPDATE SET key_json = $4"
+ " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
- "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
+ var streamID int
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
}
return nil
}
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- now := time.Now().Unix()
- return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
- for _, key := range keys {
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
- )
- if err != nil {
- return err
- }
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt32
+ err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int32
+ }
+ return
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ )
+ if err != nil {
+ return err
}
- return nil
- })
+ }
+ return nil
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
@@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
- var result []api.DeviceKeys
+ var result []api.DeviceMessage
for rows.Next() {
- var dk api.DeviceKeys
+ var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
- if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
+ var streamID int
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
+ dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 44cb0cc2..e78ee943 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
-func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
-func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys)
+func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ // work out the latest stream IDs for each user
+ userIDToStreamID := make(map[string]int)
+ for _, k := range keys {
+ userIDToStreamID[k.UserID] = 0
+ }
+ return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ for userID := range userIDToStreamID {
+ streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ userIDToStreamID[userID] = int(streamID)
+ }
+ // set the stream IDs for each key
+ for i := range keys {
+ k := keys[i]
+ userIDToStreamID[k.UserID]++ // start stream from 1
+ k.StreamID = userIDToStreamID[k.UserID]
+ keys[i] = k
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
}
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index 69fe7a6e..9f70885a 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
+ " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, device_id)" +
- " DO UPDATE SET key_json = $4"
+ " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" +
- "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
@@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
- var result []api.DeviceKeys
+ var result []api.DeviceMessage
for rows.Next() {
- var dk api.DeviceKeys
+ var dk api.DeviceMessage
dk.UserID = userID
var keyJSON string
- if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil {
+ var streamID int
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
+ dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
@@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return result, rows.Err()
}
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr)
+ var streamID int
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
}
return nil
}
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error {
- now := time.Now().Unix()
- return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error {
- for _, key := range keys {
- _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON),
- )
- if err != nil {
- return err
- }
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt32
+ err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int32
+ }
+ return
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ )
+ if err != nil {
+ return err
}
- return nil
- })
+ }
+ return nil
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index 66f6930f..b3e45e6c 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/Shopify/sarama"
+ "github.com/matrix-org/dendrite/keyserver/api"
)
var ctx = context.Background()
@@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+ db, err := NewDatabase("file::memory:", nil)
+ if err != nil {
+ t.Fatalf("Failed to NewDatabase: %s", err)
+ }
+ alice := "@alice:TestDeviceKeysStreamIDGeneration"
+ bob := "@bob:TestDeviceKeysStreamIDGeneration"
+ msgs := []api.DeviceMessage{
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1
+ },
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: bob,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1 as this is a different user
+ },
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "another_device",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 2 as this is a 2nd device key
+ },
+ }
+ MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 1 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+ }
+ if msgs[1].StreamID != 1 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+ }
+ if msgs[2].StreamID != 2 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+ }
+
+ // updating a device sets the next stream ID for that user
+ msgs = []api.DeviceMessage{
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v2"}`),
+ },
+ // StreamID: 3
+ },
+ }
+ MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 3 {
+ t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+ }
+
+ // Querying for device keys returns the latest stream IDs
+ msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
+ if err != nil {
+ t.Fatalf("DeviceKeysForUser returned error: %s", err)
+ }
+ wantStreamIDs := map[string]int{
+ "AAA": 3,
+ "another_device": 2,
+ }
+ if len(msgs) != len(wantStreamIDs) {
+ t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+ }
+ for _, m := range msgs {
+ if m.StreamID != wantStreamIDs[m.DeviceID] {
+ t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+ }
+ }
+}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index c6e43be4..65da3310 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -32,9 +32,10 @@ type OneTimeKeys interface {
}
type DeviceKeys interface {
- SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error
- InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
+ SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+ InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
}
type KeyChanges interface {