aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-02-18 11:31:05 +0000
committerGitHub <noreply@github.com>2022-02-18 11:31:05 +0000
commit153bfbbea579dfa10e8e804036f17c1a33b6fe80 (patch)
treee135dcefc59618d7b86cd8687c1a2a304385ce45 /keyserver
parent0a7dea44505f703af1e7e069602ca95aa5a83700 (diff)
Merge both user API databases into one (#2186)
* Merge user API databases into one * Remove DeviceDatabase from config * Fix tests * Try that again * Clean up keyserver device keys when the devices no longer exist in the user API * Tweak ordering * Fix UserExists flag, device check * Allow including empty entries so we can clean them up * Remove logging
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/internal/internal.go82
-rw-r--r--keyserver/storage/interface.go2
-rw-r--r--keyserver/storage/postgres/device_keys_table.go33
-rw-r--r--keyserver/storage/shared/storage.go4
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go31
-rw-r--r--keyserver/storage/storage_test.go2
-rw-r--r--keyserver/storage/tables/interface.go2
7 files changed, 110 insertions, 46 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index ffbcac94..1c6b0677 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
- msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
+ msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
- deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
) error {
- keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@@ -554,10 +554,60 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ // get a list of devices from the user API that actually exist, as
+ // we won't store keys for devices that don't exist
+ uapidevices := &userapi.QueryDevicesResponse{}
+ if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return
+ }
+ if !uapidevices.UserExists {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("user %q does not exist", req.UserID),
+ }
+ return
+ }
+ existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
+ for _, key := range uapidevices.Devices {
+ existingDeviceMap[key.ID] = struct{}{}
+ }
+
+ // Get all of the user existing device keys so we can check for changes.
+ existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
+ }
+ return
+ }
+
+ // Work out whether we have device keys in the keyserver for devices that
+ // no longer exist in the user API. This is mostly an exercise to ensure
+ // that we keep some integrity between the two.
+ var toClean []gomatrixserverlib.KeyID
+ for _, k := range existingKeys {
+ if _, ok := existingDeviceMap[k.DeviceID]; !ok {
+ toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
+ }
+ }
+
+ if len(toClean) > 0 {
+ if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()),
+ }
+ return
+ }
+ logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean))
+ }
+
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys {
- _, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
+ var serverName gomatrixserverlib.ServerName
+ _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
@@ -568,6 +618,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking
}
+ // check that the device in question actually exists in the user
+ // API before we try and store a key for it
+ if _, ok := existingDeviceMap[key.DeviceID]; !ok {
+ continue
+ }
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
@@ -583,29 +638,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
})
}
- // get existing device keys so we can check for changes
- existingKeys := make([]api.DeviceMessage, len(keysToStore))
- for i := range keysToStore {
- existingKeys[i] = api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: keysToStore[i].UserID,
- DeviceID: keysToStore[i].DeviceID,
- },
- }
- }
- if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
- }
- return
- }
if req.OnlyDisplayNameUpdates {
// add the display name field from keysToStore into existingKeys
keysToStore = appendDisplayNames(existingKeys, keysToStore)
}
// store the device keys and emit changes
- err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
+ err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 0110860e..4dffe695 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -53,7 +53,7 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index 5ae0da96..628301cf 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- countStreamIDsForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
+ return nil, err
+ }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
- rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 5914d28e..deee76eb 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
})
}
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
- return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
+func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index fa1c930d..b461424c 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
+ if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
+ return nil, err
+ }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
@@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
return err
}
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
- rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
index c4c99d8c..4d513724 100644
--- a/keyserver/storage/storage_test.go
+++ b/keyserver/storage/storage_test.go
@@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
}
// Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
+ msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index e44757e1..ff70a236 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -38,7 +38,7 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}