diff options
author | Kegsay <kegan@matrix.org> | 2020-08-05 13:41:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-05 13:41:16 +0100 |
commit | 642f9cb964b20f52133e11c52e40733f7bc07320 (patch) | |
tree | e48f667d681161a9694b45cb08feded569e539b4 /keyserver/storage | |
parent | 15dc1f4d0361da736339653ca8e6ba26ed103792 (diff) |
Process inbound device list updates from federation (#1240)
* Add InputDeviceListUpdate
* Unbreak unit tests
* Process inbound device list updates from federation
- Persist the keys in the keyserver and produce key changes
- Does not currently fetch keys from the remote server if the prev IDs are missing
* Linting
Diffstat (limited to 'keyserver/storage')
-rw-r--r-- | keyserver/storage/interface.go | 11 | ||||
-rw-r--r-- | keyserver/storage/postgres/device_keys_table.go | 21 | ||||
-rw-r--r-- | keyserver/storage/shared/storage.go | 20 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/device_keys_table.go | 23 | ||||
-rw-r--r-- | keyserver/storage/storage_test.go | 12 | ||||
-rw-r--r-- | keyserver/storage/tables/interface.go | 1 |
6 files changed, 79 insertions, 9 deletions
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 11284d86..f67bbf71 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -35,11 +35,18 @@ type Database interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // for this (user, device). // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. // Returns an error if there was a problem storing the keys. - StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index e1b4e947..d321860d 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "time" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" @@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" + type deviceKeysStatements struct { db *sql.DB upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt } func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { @@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { return nil, err } + if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { + return nil, err + } return s, nil } @@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn return } +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + // nullable if there are no results + var count sql.NullInt32 + err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { for _, key := range keys { now := time.Now().Unix() diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index e78ee943..78729774 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -47,7 +47,25 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) } -func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { +func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) { + sids := make([]int64, len(prevIDs)) + for i := range prevIDs { + sids[i] = int64(prevIDs[i]) + } + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { // work out the latest stream IDs for each user userIDToStreamID := make(map[string]int) for _, k := range keys { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 900d1238..15d9c775 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" "github.com/matrix-org/dendrite/internal" @@ -53,6 +54,9 @@ const selectBatchDeviceKeysSQL = "" + const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + type deviceKeysStatements struct { db *sql.DB writer *sqlutil.TransactionWriter @@ -143,6 +147,25 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn return } +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt32 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { for _, key := range keys { diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 949d9dd6..ec1b299f 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -126,15 +126,15 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 2 as this is a 2nd device key }, } - MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) } if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) } if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) } // updating a device sets the next stream ID for that user @@ -148,9 +148,9 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // StreamID: 3 }, } - MustNotError(t, db.StoreDeviceKeys(ctx, msgs)) + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) } // Querying for device keys returns the latest stream IDs diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 65da3310..ac932d56 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -35,6 +35,7 @@ type DeviceKeys interface { SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) } |