aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-20 14:58:03 +0100
committerGitHub <noreply@github.com>2023-02-20 14:58:03 +0100
commit4594233f89f8531fca8f696ab0ece36909130c2a (patch)
tree18d3c451041423022e15ba5fcc4a778806ff94dc /userapi
parentbd6f0c14e56af71d83d703b7c91b8cf829ca560f (diff)
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go329
-rw-r--r--userapi/consumers/clientapi.go4
-rw-r--r--userapi/consumers/devicelistupdate.go95
-rw-r--r--userapi/consumers/roomserver.go4
-rw-r--r--userapi/consumers/roomserver_test.go4
-rw-r--r--userapi/consumers/signingkeyupdate.go111
-rw-r--r--userapi/internal/cross_signing.go587
-rw-r--r--userapi/internal/device_list_update.go579
-rw-r--r--userapi/internal/device_list_update_default.go22
-rw-r--r--userapi/internal/device_list_update_sytest.go25
-rw-r--r--userapi/internal/device_list_update_test.go431
-rw-r--r--userapi/internal/key_api.go798
-rw-r--r--userapi/internal/key_api_test.go161
-rw-r--r--userapi/internal/user_api.go (renamed from userapi/internal/api.go)36
-rw-r--r--userapi/producers/keychange.go107
-rw-r--r--userapi/producers/syncapi.go4
-rw-r--r--userapi/storage/interface.go76
-rw-r--r--userapi/storage/postgres/account_data_table.go8
-rw-r--r--userapi/storage/postgres/cross_signing_keys_table.go102
-rw-r--r--userapi/storage/postgres/cross_signing_sigs_table.go131
-rw-r--r--userapi/storage/postgres/deltas/2022012016470000_key_changes.go69
-rw-r--r--userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go47
-rw-r--r--userapi/storage/postgres/device_keys_table.go213
-rw-r--r--userapi/storage/postgres/devices_table.go8
-rw-r--r--userapi/storage/postgres/key_backup_table.go4
-rw-r--r--userapi/storage/postgres/key_changes_table.go127
-rw-r--r--userapi/storage/postgres/one_time_keys_table.go194
-rw-r--r--userapi/storage/postgres/stale_device_lists.go131
-rw-r--r--userapi/storage/postgres/storage.go41
-rw-r--r--userapi/storage/shared/storage.go235
-rw-r--r--userapi/storage/sqlite3/cross_signing_keys_table.go101
-rw-r--r--userapi/storage/sqlite3/cross_signing_sigs_table.go129
-rw-r--r--userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go66
-rw-r--r--userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go71
-rw-r--r--userapi/storage/sqlite3/device_keys_table.go213
-rw-r--r--userapi/storage/sqlite3/devices_table.go17
-rw-r--r--userapi/storage/sqlite3/key_backup_table.go4
-rw-r--r--userapi/storage/sqlite3/key_changes_table.go125
-rw-r--r--userapi/storage/sqlite3/one_time_keys_table.go208
-rw-r--r--userapi/storage/sqlite3/stale_device_lists.go145
-rw-r--r--userapi/storage/sqlite3/stats_table.go3
-rw-r--r--userapi/storage/sqlite3/storage.go45
-rw-r--r--userapi/storage/storage.go27
-rw-r--r--userapi/storage/storage_test.go210
-rw-r--r--userapi/storage/storage_wasm.go4
-rw-r--r--userapi/storage/tables/interface.go46
-rw-r--r--userapi/storage/tables/stale_device_lists_test.go94
-rw-r--r--userapi/types/storage.go50
-rw-r--r--userapi/userapi.go62
-rw-r--r--userapi/userapi_test.go327
-rw-r--r--userapi/util/devices.go2
-rw-r--r--userapi/util/notify.go4
-rw-r--r--userapi/util/notify_test.go2
-rw-r--r--userapi/util/phonehomestats_test.go2
54 files changed, 6550 insertions, 90 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 4ea2e91c..fa297f77 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -15,9 +15,13 @@
package api
import (
+ "bytes"
"context"
"encoding/json"
+ "strings"
+ "time"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -26,15 +30,12 @@ import (
// UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface {
- AppserviceUserAPI
SyncUserAPI
ClientUserAPI
- MediaUserAPI
FederationUserAPI
- RoomserverUserAPI
- KeyserverUserAPI
QuerySearchProfilesAPI // used by p2p demos
+ QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the appservice api
@@ -43,11 +44,6 @@ type AppserviceUserAPI interface {
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
}
-type KeyserverUserAPI interface {
- QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
- QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
-}
-
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
@@ -60,13 +56,20 @@ type MediaUserAPI interface {
// api functions required by the federation api
type FederationUserAPI interface {
+ UploadDeviceKeysAPI
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
+ QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
+ QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// api functions required by the sync api
type SyncUserAPI interface {
QueryAcccessTokenAPI
+ SyncKeyAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@@ -79,6 +82,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
+ ClientKeyAPI
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct {
type QueryAccountByLocalpartResponse struct {
Account *Account
}
+
+// API functions required by the clientapi
+type ClientKeyAPI interface {
+ UploadDeviceKeysAPI
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
+
+ PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
+ // PerformClaimKeys claims one-time keys for use in pre-key messages
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
+ PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
+}
+
+type UploadDeviceKeysAPI interface {
+ PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
+}
+
+// API functions required by the syncapi
+type SyncKeyAPI interface {
+ QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
+ QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
+ PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
+}
+
+type FederationKeyAPI interface {
+ UploadDeviceKeysAPI
+ QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
+ QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
+ QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
+ PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
+}
+
+// KeyError is returned if there was a problem performing/querying the server
+type KeyError struct {
+ Err string `json:"error"`
+ IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
+ IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
+ IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
+}
+
+func (k *KeyError) Error() string {
+ return k.Err
+}
+
+type DeviceMessageType int
+
+const (
+ TypeDeviceKeyUpdate DeviceMessageType = iota
+ TypeCrossSigningUpdate
+)
+
+// DeviceMessage represents the message produced into Kafka by the key server.
+type DeviceMessage struct {
+ Type DeviceMessageType `json:"Type,omitempty"`
+ *DeviceKeys `json:"DeviceKeys,omitempty"`
+ *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
+ // A monotonically increasing number which represents device changes for this user.
+ StreamID int64
+ DeviceChangeID int64
+}
+
+// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
+type OutputCrossSigningKeyUpdate struct {
+ CrossSigningKeyUpdate `json:"signing_keys"`
+}
+
+type CrossSigningKeyUpdate struct {
+ MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
+ SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
+ UserID string `json:"user_id"`
+}
+
+// DeviceKeysEqual returns true if the device keys updates contain the
+// same display name and key JSON. This will return false if either of
+// the updates is not a device keys update, or if the user ID/device ID
+// differ between the two.
+func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
+ if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
+ return false
+ }
+ if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
+ return false
+ }
+ if m1.DisplayName != m2.DisplayName {
+ return false // different display names
+ }
+ if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
+ return false // either is empty
+ }
+ return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
+}
+
+// DeviceKeys represents a set of device keys for a single device
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
+type DeviceKeys struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // The device display name
+ DisplayName string
+ // The raw device key JSON
+ KeyJSON []byte
+}
+
+// WithStreamID returns a copy of this device message with the given stream ID
+func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
+ return DeviceMessage{
+ DeviceKeys: k,
+ StreamID: streamID,
+ }
+}
+
+// OneTimeKeys represents a set of one-time keys for a single device
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
+type OneTimeKeys struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // A map of algorithm:key_id => key JSON
+ KeyJSON map[string]json.RawMessage
+}
+
+// Split a key in KeyJSON into algorithm and key ID
+func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
+ segments := strings.Split(keyIDWithAlgo, ":")
+ return segments[0], segments[1]
+}
+
+// OneTimeKeysCount represents the counts of one-time keys for a single device
+type OneTimeKeysCount struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // algorithm to count e.g:
+ // {
+ // "curve25519": 10,
+ // "signed_curve25519": 20
+ // }
+ KeyCount map[string]int
+}
+
+// PerformUploadKeysRequest is the request to PerformUploadKeys
+type PerformUploadKeysRequest struct {
+ UserID string // Required - User performing the request
+ DeviceID string // Optional - Device performing the request, for fetching OTK count
+ DeviceKeys []DeviceKeys
+ OneTimeKeys []OneTimeKeys
+ // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
+ // the display name for their respective device, and NOT to modify the keys. The key
+ // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
+ // Without this flag, requests to modify device display names would delete device keys.
+ OnlyDisplayNameUpdates bool
+}
+
+// PerformUploadKeysResponse is the response to PerformUploadKeys
+type PerformUploadKeysResponse struct {
+ // A fatal error when processing e.g database failures
+ Error *KeyError
+ // A map of user_id -> device_id -> Error for tracking failures.
+ KeyErrors map[string]map[string]*KeyError
+ OneTimeKeyCounts []OneTimeKeysCount
+}
+
+// PerformDeleteKeysRequest asks the keyserver to forget about certain
+// keys, and signatures related to those keys.
+type PerformDeleteKeysRequest struct {
+ UserID string
+ KeyIDs []gomatrixserverlib.KeyID
+}
+
+// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
+type PerformDeleteKeysResponse struct {
+ Error *KeyError
+}
+
+// KeyError sets a key error field on KeyErrors
+func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
+ if r.KeyErrors[userID] == nil {
+ r.KeyErrors[userID] = make(map[string]*KeyError)
+ }
+ r.KeyErrors[userID][deviceID] = err
+}
+
+type PerformClaimKeysRequest struct {
+ // Map of user_id to device_id to algorithm name
+ OneTimeKeys map[string]map[string]string
+ Timeout time.Duration
+}
+
+type PerformClaimKeysResponse struct {
+ // Map of user_id to device_id to algorithm:key_id to key JSON
+ OneTimeKeys map[string]map[string]map[string]json.RawMessage
+ // Map of remote server domain to error JSON
+ Failures map[string]interface{}
+ // Set if there was a fatal error processing this action
+ Error *KeyError
+}
+
+type PerformUploadDeviceKeysRequest struct {
+ gomatrixserverlib.CrossSigningKeys
+ // The user that uploaded the key, should be populated by the clientapi.
+ UserID string
+}
+
+type PerformUploadDeviceKeysResponse struct {
+ Error *KeyError
+}
+
+type PerformUploadDeviceSignaturesRequest struct {
+ Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
+ // The user that uploaded the sig, should be populated by the clientapi.
+ UserID string
+}
+
+type PerformUploadDeviceSignaturesResponse struct {
+ Error *KeyError
+}
+
+type QueryKeysRequest struct {
+ // The user ID asking for the keys, e.g. if from a client API request.
+ // Will not be populated if the key request came from federation.
+ UserID string
+ // Maps user IDs to a list of devices
+ UserToDevices map[string][]string
+ Timeout time.Duration
+}
+
+type QueryKeysResponse struct {
+ // Map of remote server domain to error JSON
+ Failures map[string]interface{}
+ // Map of user_id to device_id to device_key
+ DeviceKeys map[string]map[string]json.RawMessage
+ // Maps of user_id to cross signing key
+ MasterKeys map[string]gomatrixserverlib.CrossSigningKey
+ SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // Set if there was a fatal error processing this query
+ Error *KeyError
+}
+
+type QueryKeyChangesRequest struct {
+ // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
+ Offset int64
+ // The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
+ // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
+ ToOffset int64
+}
+
+type QueryKeyChangesResponse struct {
+ // The set of users who have had their keys change.
+ UserIDs []string
+ // The latest offset represented in this response.
+ Offset int64
+ // Set if there was a problem handling the request.
+ Error *KeyError
+}
+
+type QueryOneTimeKeysRequest struct {
+ // The local user to query OTK counts for
+ UserID string
+ // The device to query OTK counts for
+ DeviceID string
+}
+
+type QueryOneTimeKeysResponse struct {
+ // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
+ Count OneTimeKeysCount
+ Error *KeyError
+}
+
+type QueryDeviceMessagesRequest struct {
+ UserID string
+}
+
+type QueryDeviceMessagesResponse struct {
+ // The latest stream ID
+ StreamID int64
+ Devices []DeviceMessage
+ Error *KeyError
+}
+
+type QuerySignaturesRequest struct {
+ // A map of target user ID -> target key/device IDs to retrieve signatures for
+ TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
+}
+
+type QuerySignaturesResponse struct {
+ // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
+ Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
+ // A map of target user ID -> cross-signing master key
+ MasterKeys map[string]gomatrixserverlib.CrossSigningKey
+ // A map of target user ID -> cross-signing self-signing key
+ SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // A map of target user ID -> cross-signing user-signing key
+ UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
+ // The request error, if any
+ Error *KeyError
+}
+
+type PerformMarkAsStaleRequest struct {
+ UserID string
+ Domain gomatrixserverlib.ServerName
+ DeviceID string
+}
diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go
index 42ae72e7..51bd2753 100644
--- a/userapi/consumers/clientapi.go
+++ b/userapi/consumers/clientapi.go
@@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct {
jetstream nats.JetStreamContext
durable string
topic string
- db storage.Database
+ db storage.UserDatabase
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
@@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
- store storage.Database,
+ store storage.UserDatabase,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go
new file mode 100644
index 00000000..a65889fc
--- /dev/null
+++ b/userapi/consumers/devicelistupdate.go
@@ -0,0 +1,95 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/userapi/internal"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+)
+
+// DeviceListUpdateConsumer consumes device list updates that came in over federation.
+type DeviceListUpdateConsumer struct {
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ updater *internal.DeviceListUpdater
+ isLocalServerName func(gomatrixserverlib.ServerName) bool
+}
+
+// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
+func NewDeviceListUpdateConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ updater *internal.DeviceListUpdater,
+) *DeviceListUpdateConsumer {
+ return &DeviceListUpdateConsumer{
+ ctx: process.Context(),
+ jetstream: js,
+ durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
+ topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
+ updater: updater,
+ isLocalServerName: cfg.Matrix.IsLocalServerName,
+ }
+}
+
+// Start consuming from key servers
+func (t *DeviceListUpdateConsumer) Start() error {
+ return jetstream.JetStreamConsumer(
+ t.ctx, t.jetstream, t.topic, t.durable, 1,
+ t.onMessage, nats.DeliverAll(), nats.ManualAck(),
+ )
+}
+
+// onMessage is called in response to a message received on the
+// key change events topic from the key server.
+func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
+ msg := msgs[0] // Guaranteed to exist if onMessage is called
+ var m gomatrixserverlib.DeviceListUpdateEvent
+ if err := json.Unmarshal(msg.Data, &m); err != nil {
+ logrus.WithError(err).Errorf("Failed to read from device list update input topic")
+ return true
+ }
+ origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
+ if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
+ return true
+ } else if t.isLocalServerName(serverName) {
+ return true
+ } else if serverName != origin {
+ return true
+ }
+
+ err := t.updater.Update(ctx, m)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "user_id": m.UserID,
+ "device_id": m.DeviceID,
+ "stream_id": m.StreamID,
+ "prev_id": m.PrevID,
+ }).WithError(err).Errorf("Failed to update device list")
+ return false
+ }
+ return true
+}
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index 3ce5af62..47d33095 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
rsAPI rsapi.UserRoomserverAPI
jetstream nats.JetStreamContext
durable string
- db storage.Database
+ db storage.UserDatabase
topic string
pgClient pushgateway.Client
syncProducer *producers.SyncAPI
@@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
- store storage.Database,
+ store storage.UserDatabase,
pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI,
diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go
index 39f4aab4..bc5ae652 100644
--- a/userapi/consumers/roomserver_test.go
+++ b/userapi/consumers/roomserver_test.go
@@ -18,11 +18,11 @@ import (
userAPITypes "github.com/matrix-org/dendrite/userapi/types"
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")
if err != nil {
diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go
new file mode 100644
index 00000000..f4ff017d
--- /dev/null
+++ b/userapi/consumers/signingkeyupdate.go
@@ -0,0 +1,111 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/userapi/api"
+)
+
+// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
+type SigningKeyUpdateConsumer struct {
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ userAPI api.UploadDeviceKeysAPI
+ cfg *config.UserAPI
+ isLocalServerName func(gomatrixserverlib.ServerName) bool
+}
+
+// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
+func NewSigningKeyUpdateConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ userAPI api.UploadDeviceKeysAPI,
+) *SigningKeyUpdateConsumer {
+ return &SigningKeyUpdateConsumer{
+ ctx: process.Context(),
+ jetstream: js,
+ durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
+ topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
+ userAPI: userAPI,
+ cfg: cfg,
+ isLocalServerName: cfg.Matrix.IsLocalServerName,
+ }
+}
+
+// Start consuming from key servers
+func (t *SigningKeyUpdateConsumer) Start() error {
+ return jetstream.JetStreamConsumer(
+ t.ctx, t.jetstream, t.topic, t.durable, 1,
+ t.onMessage, nats.DeliverAll(), nats.ManualAck(),
+ )
+}
+
+// onMessage is called in response to a message received on the
+// signing key update events topic from the key server.
+func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
+ msg := msgs[0] // Guaranteed to exist if onMessage is called
+ var updatePayload api.CrossSigningKeyUpdate
+ if err := json.Unmarshal(msg.Data, &updatePayload); err != nil {
+ logrus.WithError(err).Errorf("Failed to read from signing key update input topic")
+ return true
+ }
+ origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
+ if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
+ logrus.WithError(err).Error("failed to split user id")
+ return true
+ } else if t.isLocalServerName(serverName) {
+ logrus.Warn("dropping device key update from ourself")
+ return true
+ } else if serverName != origin {
+ logrus.Warnf("dropping device key update, %s != %s", serverName, origin)
+ return true
+ }
+
+ keys := gomatrixserverlib.CrossSigningKeys{}
+ if updatePayload.MasterKey != nil {
+ keys.MasterKey = *updatePayload.MasterKey
+ }
+ if updatePayload.SelfSigningKey != nil {
+ keys.SelfSigningKey = *updatePayload.SelfSigningKey
+ }
+ uploadReq := &api.PerformUploadDeviceKeysRequest{
+ CrossSigningKeys: keys,
+ UserID: updatePayload.UserID,
+ }
+ uploadRes := &api.PerformUploadDeviceKeysResponse{}
+ if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
+ logrus.WithError(err).Error("failed to upload device keys")
+ return false
+ }
+ if uploadRes.Error != nil {
+ logrus.WithError(uploadRes.Error).Error("failed to upload device keys")
+ return true
+ }
+
+ return true
+}
diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go
new file mode 100644
index 00000000..8b9704d1
--- /dev/null
+++ b/userapi/internal/cross_signing.go
@@ -0,0 +1,587 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "bytes"
+ "context"
+ "crypto/ed25519"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/curve25519"
+)
+
+func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error {
+ // Is there exactly one key?
+ if len(key.Keys) != 1 {
+ return fmt.Errorf("should contain exactly one key")
+ }
+
+ // Does the key ID match the key value? Iterates exactly once
+ for keyID, keyData := range key.Keys {
+ b64 := keyData.Encode()
+ tokens := strings.Split(string(keyID), ":")
+ if len(tokens) != 2 {
+ return fmt.Errorf("key ID is incorrectly formatted")
+ }
+ if tokens[1] != b64 {
+ return fmt.Errorf("key ID isn't correct")
+ }
+ switch tokens[0] {
+ case "ed25519":
+ if len(keyData) != ed25519.PublicKeySize {
+ return fmt.Errorf("ed25519 key is not the correct length")
+ }
+ case "curve25519":
+ if len(keyData) != curve25519.PointSize {
+ return fmt.Errorf("curve25519 key is not the correct length")
+ }
+ default:
+ // We can't enforce the key length to be correct for an
+ // algorithm that we don't recognise, so instead we'll
+ // just make sure that it isn't incredibly excessive.
+ if l := len(keyData); l > 4096 {
+ return fmt.Errorf("unknown key type is too long (%d bytes)", l)
+ }
+ }
+ }
+
+ // Check to see if the signatures make sense
+ for _, forOriginUser := range key.Signatures {
+ for originKeyID, originSignature := range forOriginUser {
+ switch strings.SplitN(string(originKeyID), ":", 1)[0] {
+ case "ed25519":
+ if len(originSignature) != ed25519.SignatureSize {
+ return fmt.Errorf("ed25519 signature is not the correct length")
+ }
+ case "curve25519":
+ return fmt.Errorf("curve25519 signatures are impossible")
+ default:
+ if l := len(originSignature); l > 4096 {
+ return fmt.Errorf("unknown signature type is too long (%d bytes)", l)
+ }
+ }
+ }
+ }
+
+ // Does the key claim to be from the right user?
+ if userID != key.UserID {
+ return fmt.Errorf("key has a user ID mismatch")
+ }
+
+ // Does the key contain the correct purpose?
+ useful := false
+ for _, usage := range key.Usage {
+ if usage == purpose {
+ useful = true
+ break
+ }
+ }
+ if !useful {
+ return fmt.Errorf("key does not contain correct usage purpose")
+ }
+
+ return nil
+}
+
+// nolint:gocyclo
+func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
+ // Find the keys to store.
+ byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
+ toStore := types.CrossSigningKeyMap{}
+ hasMasterKey := false
+
+ if len(req.MasterKey.Keys) > 0 {
+ if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
+ res.Error = &api.KeyError{
+ Err: "Master key sanity check failed: " + err.Error(),
+ IsInvalidParam: true,
+ }
+ return nil
+ }
+
+ byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
+ for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
+ toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key
+ }
+ hasMasterKey = true
+ }
+
+ if len(req.SelfSigningKey.Keys) > 0 {
+ if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil {
+ res.Error = &api.KeyError{
+ Err: "Self-signing key sanity check failed: " + err.Error(),
+ IsInvalidParam: true,
+ }
+ return nil
+ }
+
+ byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
+ for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
+ toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key
+ }
+ }
+
+ if len(req.UserSigningKey.Keys) > 0 {
+ if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil {
+ res.Error = &api.KeyError{
+ Err: "User-signing key sanity check failed: " + err.Error(),
+ IsInvalidParam: true,
+ }
+ return nil
+ }
+
+ byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
+ for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
+ toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key
+ }
+ }
+
+ // If there's nothing to do then stop here.
+ if len(toStore) == 0 {
+ res.Error = &api.KeyError{
+ Err: "No keys were supplied in the request",
+ IsMissingParam: true,
+ }
+ return nil
+ }
+
+ // We can't have a self-signing or user-signing key without a master
+ // key, so make sure we have one of those. We will also only actually do
+ // something if any of the specified keys in the request are different
+ // to what we've got in the database, to avoid generating key change
+ // notifications unnecessarily.
+ existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
+ }
+ return nil
+ }
+
+ // If we still can't find a master key for the user then stop the upload.
+ // This satisfies the "Fails to upload self-signing key without master key" test.
+ if !hasMasterKey {
+ if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
+ res.Error = &api.KeyError{
+ Err: "No master key was found",
+ IsMissingParam: true,
+ }
+ return nil
+ }
+ }
+
+ // Check if anything actually changed compared to what we have in the database.
+ changed := false
+ for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
+ gomatrixserverlib.CrossSigningKeyPurposeMaster,
+ gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
+ gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
+ } {
+ old, gotOld := existingKeys[purpose]
+ new, gotNew := toStore[purpose]
+ if gotOld != gotNew {
+ // A new key purpose has been specified that we didn't know before,
+ // or one has been removed.
+ changed = true
+ break
+ }
+ if !bytes.Equal(old, new) {
+ // One of the existing keys for a purpose we already knew about has
+ // changed.
+ changed = true
+ break
+ }
+ }
+ if !changed {
+ return nil
+ }
+
+ // Store the keys.
+ if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
+ }
+ return nil
+ }
+
+ // Now upload any signatures that were included with the keys.
+ for _, key := range byPurpose {
+ var targetKeyID gomatrixserverlib.KeyID
+ for targetKey := range key.Keys { // iterates once, see sanityCheckKey
+ targetKeyID = targetKey
+ }
+ for sigUserID, forSigUserID := range key.Signatures {
+ if sigUserID != req.UserID {
+ continue
+ }
+ for sigKeyID, sigBytes := range forSigUserID {
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
+ }
+ return nil
+ }
+ }
+ }
+ }
+
+ // Finally, generate a notification that we updated the keys.
+ update := api.CrossSigningKeyUpdate{
+ UserID: req.UserID,
+ }
+ if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok {
+ update.MasterKey = &mk
+ }
+ if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok {
+ update.SelfSigningKey = &ssk
+ }
+ if update.MasterKey == nil && update.SelfSigningKey == nil {
+ return nil
+ }
+ if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
+ }
+ return nil
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
+ // Before we do anything, we need the master and self-signing keys for this user.
+ // Then we can verify the signatures make sense.
+ queryReq := &api.QueryKeysRequest{
+ UserID: req.UserID,
+ UserToDevices: map[string][]string{},
+ }
+ queryRes := &api.QueryKeysResponse{}
+ for userID := range req.Signatures {
+ queryReq.UserToDevices[userID] = []string{}
+ }
+ _ = a.QueryKeys(ctx, queryReq, queryRes)
+
+ selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+ otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+
+ // Sort signatures into two groups: one where people have signed their own
+ // keys and one where people have signed someone elses
+ for userID, forUserID := range req.Signatures {
+ for keyID, keyOrDevice := range forUserID {
+ switch key := keyOrDevice.CrossSigningBody.(type) {
+ case *gomatrixserverlib.CrossSigningKey:
+ if key.UserID == req.UserID {
+ if _, ok := selfSignatures[userID]; !ok {
+ selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+ }
+ selfSignatures[userID][keyID] = keyOrDevice
+ } else {
+ if _, ok := otherSignatures[userID]; !ok {
+ otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+ }
+ otherSignatures[userID][keyID] = keyOrDevice
+ }
+
+ case *gomatrixserverlib.DeviceKeys:
+ if key.UserID == req.UserID {
+ if _, ok := selfSignatures[userID]; !ok {
+ selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+ }
+ selfSignatures[userID][keyID] = keyOrDevice
+ } else {
+ if _, ok := otherSignatures[userID]; !ok {
+ otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
+ }
+ otherSignatures[userID][keyID] = keyOrDevice
+ }
+
+ default:
+ continue
+ }
+ }
+ }
+
+ if err := a.processSelfSignatures(ctx, selfSignatures); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
+ }
+ return nil
+ }
+
+ if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
+ }
+ return nil
+ }
+
+ // Finally, generate a notification that we updated the signatures.
+ for userID := range req.Signatures {
+ masterKey := queryRes.MasterKeys[userID]
+ selfSigningKey := queryRes.SelfSigningKeys[userID]
+ update := api.CrossSigningKeyUpdate{
+ UserID: userID,
+ MasterKey: &masterKey,
+ SelfSigningKey: &selfSigningKey,
+ }
+ if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
+ }
+ return nil
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) processSelfSignatures(
+ ctx context.Context,
+ signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
+) error {
+ // Here we will process:
+ // * The user signing their own devices using their self-signing key
+ // * The user signing their master key using one of their devices
+
+ for targetUserID, forTargetUserID := range signatures {
+ for targetKeyID, signature := range forTargetUserID {
+ switch sig := signature.CrossSigningBody.(type) {
+ case *gomatrixserverlib.CrossSigningKey:
+ for keyID := range sig.Keys {
+ split := strings.SplitN(string(keyID), ":", 2)
+ if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
+ targetKeyID = keyID // contains the ed25519: or other scheme
+ break
+ }
+ }
+ for originUserID, forOriginUserID := range sig.Signatures {
+ for originKeyID, originSig := range forOriginUserID {
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
+ ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
+ ); err != nil {
+ return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
+ }
+ }
+ }
+
+ case *gomatrixserverlib.DeviceKeys:
+ for originUserID, forOriginUserID := range sig.Signatures {
+ for originKeyID, originSig := range forOriginUserID {
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
+ ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
+ ); err != nil {
+ return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
+ }
+ }
+ }
+
+ default:
+ return fmt.Errorf("unexpected type assertion")
+ }
+ }
+ }
+
+ return nil
+}
+
+func (a *UserInternalAPI) processOtherSignatures(
+ ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
+ signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
+) error {
+ // Here we will process:
+ // * A user signing someone else's master keys using their user-signing keys
+
+ for targetUserID, forTargetUserID := range signatures {
+ for _, signature := range forTargetUserID {
+ switch sig := signature.CrossSigningBody.(type) {
+ case *gomatrixserverlib.CrossSigningKey:
+ // Find the local copy of the master key. We'll use this to be
+ // sure that the supplied stanza matches the key that we think it
+ // should be.
+ masterKey, ok := queryRes.MasterKeys[targetUserID]
+ if !ok {
+ return fmt.Errorf("failed to find master key for user %q", targetUserID)
+ }
+
+ // For each key ID, write the signatures. Maybe there'll be more
+ // than one algorithm in the future so it's best not to focus on
+ // everything being ed25519:.
+ for targetKeyID, suppliedKeyData := range sig.Keys {
+ // The master key will be supplied in the request, but we should
+ // make sure that it matches what we think the master key should
+ // actually be.
+ localKeyData, lok := masterKey.Keys[targetKeyID]
+ if !lok {
+ return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
+ } else if !bytes.Equal(suppliedKeyData, localKeyData) {
+ return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
+ }
+
+ // We only care about the signatures from the uploading user, so
+ // we will ignore anything that didn't originate from them.
+ userSigs, ok := sig.Signatures[userID]
+ if !ok {
+ return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID)
+ }
+
+ for originKeyID, originSig := range userSigs {
+ if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
+ ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
+ ); err != nil {
+ return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
+ }
+ }
+ }
+
+ default:
+ // Users should only be signing another person's master key,
+ // so if we're here, it's probably because it's actually a
+ // gomatrixserverlib.DeviceKeys, which doesn't make sense.
+ }
+ }
+ }
+
+ return nil
+}
+
+func (a *UserInternalAPI) crossSigningKeysFromDatabase(
+ ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
+) {
+ for targetUserID := range req.UserToDevices {
+ keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
+ if err != nil {
+ logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
+ continue
+ }
+
+ for keyType, key := range keys {
+ var keyID gomatrixserverlib.KeyID
+ for id := range key.Keys {
+ keyID = id
+ break
+ }
+
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
+ if err != nil && err != sql.ErrNoRows {
+ logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
+ continue
+ }
+
+ appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) {
+ if key.Signatures == nil {
+ key.Signatures = types.CrossSigningSigMap{}
+ }
+ if _, ok := key.Signatures[originUserID]; !ok {
+ key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes)
+ }
+ key.Signatures[originUserID][originKeyID] = signature
+ }
+
+ for originUserID, forOrigin := range sigMap {
+ for originKeyID, signature := range forOrigin {
+ switch {
+ case req.UserID != "" && originUserID == req.UserID:
+ // Include signatures that we created
+ appendSignature(originUserID, originKeyID, signature)
+ case originUserID == targetUserID:
+ // Include signatures that were created by the person whose key
+ // we are processing
+ appendSignature(originUserID, originKeyID, signature)
+ }
+ }
+ }
+
+ switch keyType {
+ case gomatrixserverlib.CrossSigningKeyPurposeMaster:
+ res.MasterKeys[targetUserID] = key
+
+ case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
+ res.SelfSigningKeys[targetUserID] = key
+
+ case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
+ res.UserSigningKeys[targetUserID] = key
+ }
+ }
+ }
+}
+
+func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
+ for targetUserID, forTargetUser := range req.TargetIDs {
+ keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
+ if err != nil && err != sql.ErrNoRows {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
+ }
+ continue
+ }
+
+ for targetPurpose, targetKey := range keyMap {
+ switch targetPurpose {
+ case gomatrixserverlib.CrossSigningKeyPurposeMaster:
+ if res.MasterKeys == nil {
+ res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{}
+ }
+ res.MasterKeys[targetUserID] = targetKey
+
+ case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
+ if res.SelfSigningKeys == nil {
+ res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
+ }
+ res.SelfSigningKeys[targetUserID] = targetKey
+
+ case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
+ if res.UserSigningKeys == nil {
+ res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
+ }
+ res.UserSigningKeys[targetUserID] = targetKey
+ }
+ }
+
+ for _, targetKeyID := range forTargetUser {
+ // Get own signatures only.
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
+ if err != nil && err != sql.ErrNoRows {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
+ }
+ return nil
+ }
+
+ for sourceUserID, forSourceUser := range sigMap {
+ for sourceKeyID, sourceSig := range forSourceUser {
+ if res.Signatures == nil {
+ res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
+ }
+ if _, ok := res.Signatures[targetUserID]; !ok {
+ res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
+ }
+ if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok {
+ res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{}
+ }
+ if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok {
+ res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig
+ }
+ }
+ }
+ }
+ return nil
+}
diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go
new file mode 100644
index 00000000..3b4dcf98
--- /dev/null
+++ b/userapi/internal/device_list_update.go
@@ -0,0 +1,579 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "hash/fnv"
+ "net"
+ "sync"
+ "time"
+
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+
+ "github.com/matrix-org/gomatrix"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/sirupsen/logrus"
+
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/userapi/api"
+)
+
+var (
+ deviceListUpdateCount = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Namespace: "dendrite",
+ Subsystem: "keyserver",
+ Name: "device_list_update",
+ Help: "Number of times we have attempted to update device lists from this server",
+ },
+ []string{"server"},
+ )
+)
+
+const requestTimeout = time.Second * 30
+
+func init() {
+ prometheus.MustRegister(
+ deviceListUpdateCount,
+ )
+}
+
+// DeviceListUpdater handles device list updates from remote servers.
+//
+// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
+// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
+// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
+// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
+// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
+// updater stores the latest list along with the latest stream ID.
+//
+// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
+// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
+// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
+// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
+// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
+// for that domain.
+// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
+// we have many many servers)
+// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
+//
+// The downsides are that:
+// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
+// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
+// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
+// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
+// than being stuck behind foo.bar
+//
+// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
+// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
+type DeviceListUpdater struct {
+ process *process.ProcessContext
+ // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
+ // request to the remote server and race.
+ // TODO: Put in an LRU cache to bound growth
+ userIDToMutex map[string]*sync.Mutex
+ mu *sync.Mutex // protects UserIDToMutex
+
+ db DeviceListUpdaterDatabase
+ api DeviceListUpdaterAPI
+ producer KeyChangeProducer
+ fedClient fedsenderapi.KeyserverFederationAPI
+ workerChans []chan gomatrixserverlib.ServerName
+ thisServer gomatrixserverlib.ServerName
+
+ // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
+ // block on or timeout via a select.
+ userIDToChan map[string]chan bool
+ userIDToChanMu *sync.Mutex
+ rsAPI rsapi.KeyserverRoomserverAPI
+}
+
+// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
+// Useful for testing.
+type DeviceListUpdaterDatabase interface {
+ // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+ // If no domains are given, all user IDs with stale device lists are returned.
+ StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+
+ // MarkDeviceListStale sets the stale bit for this user to isStale.
+ MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
+
+ // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
+
+ // PrevIDsExists returns true if all prev IDs exist for this user.
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
+
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+
+ DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
+}
+
+type DeviceListUpdaterAPI interface {
+ PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
+}
+
+// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
+type KeyChangeProducer interface {
+ ProduceKeyChanges(keys []api.DeviceMessage) error
+}
+
+// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
+func NewDeviceListUpdater(
+ process *process.ProcessContext, db DeviceListUpdaterDatabase,
+ api DeviceListUpdaterAPI, producer KeyChangeProducer,
+ fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
+ rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
+) *DeviceListUpdater {
+ return &DeviceListUpdater{
+ process: process,
+ userIDToMutex: make(map[string]*sync.Mutex),
+ mu: &sync.Mutex{},
+ db: db,
+ api: api,
+ producer: producer,
+ fedClient: fedClient,
+ thisServer: thisServer,
+ workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
+ userIDToChan: make(map[string]chan bool),
+ userIDToChanMu: &sync.Mutex{},
+ rsAPI: rsAPI,
+ }
+}
+
+// Start the device list updater, which will try to refresh any stale device lists.
+func (u *DeviceListUpdater) Start() error {
+ for i := 0; i < len(u.workerChans); i++ {
+ // Allocate a small buffer per channel.
+ // If the buffer limit is reached, backpressure will cause the processing of EDUs
+ // to stop (in this transaction) until key requests can be made.
+ ch := make(chan gomatrixserverlib.ServerName, 10)
+ u.workerChans[i] = ch
+ go u.worker(ch)
+ }
+
+ staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
+ if err != nil {
+ return err
+ }
+ offset, step := time.Second*10, time.Second
+ if max := len(staleLists); max > 120 {
+ step = (time.Second * 120) / time.Duration(max)
+ }
+ for _, userID := range staleLists {
+ userID := userID // otherwise we are only sending the last entry
+ time.AfterFunc(offset, func() {
+ u.notifyWorkers(userID)
+ })
+ offset += step
+ }
+ return nil
+}
+
+// CleanUp removes stale device entries for users we don't share a room with anymore
+func (u *DeviceListUpdater) CleanUp() error {
+ staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
+ if err != nil {
+ return err
+ }
+
+ res := rsapi.QueryLeftUsersResponse{}
+ if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
+ return err
+ }
+
+ if len(res.LeftUsers) == 0 {
+ return nil
+ }
+ logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
+ return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
+}
+
+func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ if u.userIDToMutex[userID] == nil {
+ u.userIDToMutex[userID] = &sync.Mutex{}
+ }
+ return u.userIDToMutex[userID]
+}
+
+// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
+// Blocks until the device list is synced or the timeout is reached.
+func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
+ mu := u.mutex(userID)
+ mu.Lock()
+ err := u.db.MarkDeviceListStale(ctx, userID, true)
+ mu.Unlock()
+ if err != nil {
+ return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
+ }
+ u.notifyWorkers(userID)
+ return nil
+}
+
+// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
+// which assumes when /send 200 OKs that the device lists have been updated.
+func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
+ isDeviceListStale, err := u.update(ctx, event)
+ if err != nil {
+ return err
+ }
+ if isDeviceListStale {
+ // poke workers to handle stale device lists
+ u.notifyWorkers(event.UserID)
+ }
+ return nil
+}
+
+func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
+ mu := u.mutex(event.UserID)
+ mu.Lock()
+ defer mu.Unlock()
+ // check if we have the prev IDs
+ exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
+ if err != nil {
+ return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+ // if this is the first time we're hearing about this user, sync the device list manually.
+ if len(event.PrevID) == 0 {
+ exists = false
+ }
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "prev_ids_exist": exists,
+ "user_id": event.UserID,
+ "device_id": event.DeviceID,
+ "stream_id": event.StreamID,
+ "prev_ids": event.PrevID,
+ "display_name": event.DeviceDisplayName,
+ "deleted": event.Deleted,
+ }).Trace("DeviceListUpdater.Update")
+
+ // if we haven't missed anything update the database and notify users
+ if exists || event.Deleted {
+ k := event.Keys
+ if event.Deleted {
+ k = nil
+ }
+ keys := []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: event.DeviceID,
+ DisplayName: event.DeviceDisplayName,
+ KeyJSON: k,
+ UserID: event.UserID,
+ },
+ StreamID: event.StreamID,
+ },
+ }
+
+ // DeviceKeysJSON will side-effect modify this, so it needs
+ // to be a copy, not sharing any pointers with the above.
+ deviceKeysCopy := *keys[0].DeviceKeys
+ deviceKeysCopy.KeyJSON = nil
+ existingKeys := []api.DeviceMessage{
+ {
+ Type: keys[0].Type,
+ DeviceKeys: &deviceKeysCopy,
+ StreamID: keys[0].StreamID,
+ },
+ }
+
+ // fetch what keys we had already and only emit changes
+ if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
+ // non-fatal, log and continue
+ util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf(
+ "failed to query device keys json for calculating diffs",
+ )
+ }
+
+ err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
+ if err != nil {
+ return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+
+ if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
+ return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+ return false, nil
+ }
+
+ err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
+ if err != nil {
+ return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
+ }
+
+ return true, nil
+}
+
+func (u *DeviceListUpdater) notifyWorkers(userID string) {
+ _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return
+ }
+ hash := fnv.New32a()
+ _, _ = hash.Write([]byte(remoteServer))
+ index := int(int64(hash.Sum32()) % int64(len(u.workerChans)))
+
+ ch := u.assignChannel(userID)
+ u.workerChans[index] <- remoteServer
+ select {
+ case <-ch:
+ case <-time.After(10 * time.Second):
+ // we don't return an error in this case as it's not a failure condition.
+ // we mainly block for the benefit of sytest anyway
+ }
+}
+
+func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
+ u.userIDToChanMu.Lock()
+ defer u.userIDToChanMu.Unlock()
+ if ch, ok := u.userIDToChan[userID]; ok {
+ return ch
+ }
+ ch := make(chan bool)
+ u.userIDToChan[userID] = ch
+ return ch
+}
+
+func (u *DeviceListUpdater) clearChannel(userID string) {
+ u.userIDToChanMu.Lock()
+ defer u.userIDToChanMu.Unlock()
+ if ch, ok := u.userIDToChan[userID]; ok {
+ close(ch)
+ delete(u.userIDToChan, userID)
+ }
+}
+
+func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
+ retries := make(map[gomatrixserverlib.ServerName]time.Time)
+ retriesMu := &sync.Mutex{}
+ // restarter goroutine which will inject failed servers into ch when it is time
+ go func() {
+ var serversToRetry []gomatrixserverlib.ServerName
+ for {
+ serversToRetry = serversToRetry[:0] // reuse memory
+ time.Sleep(time.Second)
+ retriesMu.Lock()
+ now := time.Now()
+ for srv, retryAt := range retries {
+ if now.After(retryAt) {
+ serversToRetry = append(serversToRetry, srv)
+ }
+ }
+ for _, srv := range serversToRetry {
+ delete(retries, srv)
+ }
+ retriesMu.Unlock()
+ for _, srv := range serversToRetry {
+ ch <- srv
+ }
+ }
+ }()
+ for serverName := range ch {
+ retriesMu.Lock()
+ _, exists := retries[serverName]
+ retriesMu.Unlock()
+ if exists {
+ // Don't retry a server that we're already waiting for.
+ continue
+ }
+ waitTime, shouldRetry := u.processServer(serverName)
+ if shouldRetry {
+ retriesMu.Lock()
+ if _, exists = retries[serverName]; !exists {
+ retries[serverName] = time.Now().Add(waitTime)
+ }
+ retriesMu.Unlock()
+ }
+ }
+}
+
+func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
+ ctx := u.process.Context()
+ logger := util.GetLogger(ctx).WithField("server_name", serverName)
+ deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
+
+ waitTime := defaultWaitTime // How long should we wait to try again?
+ successCount := 0 // How many user requests failed?
+
+ userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
+ if err != nil {
+ logger.WithError(err).Error("Failed to load stale device lists")
+ return waitTime, true
+ }
+
+ defer func() {
+ for _, userID := range userIDs {
+ // always clear the channel to unblock Update calls regardless of success/failure
+ u.clearChannel(userID)
+ }
+ }()
+
+ for _, userID := range userIDs {
+ userWait, err := u.processServerUser(ctx, serverName, userID)
+ if err != nil {
+ if userWait > waitTime {
+ waitTime = userWait
+ }
+ break
+ }
+ successCount++
+ }
+
+ allUsersSucceeded := successCount == len(userIDs)
+ if !allUsersSucceeded {
+ logger.WithFields(logrus.Fields{
+ "total": len(userIDs),
+ "succeeded": successCount,
+ "failed": len(userIDs) - successCount,
+ "wait_time": waitTime,
+ }).Debug("Failed to query device keys for some users")
+ }
+ return waitTime, !allUsersSucceeded
+}
+
+func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
+ ctx, cancel := context.WithTimeout(ctx, requestTimeout)
+ defer cancel()
+ logger := util.GetLogger(ctx).WithFields(logrus.Fields{
+ "server_name": serverName,
+ "user_id": userID,
+ })
+ res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return time.Minute * 10, err
+ }
+ switch e := err.(type) {
+ case *json.UnmarshalTypeError, *json.SyntaxError:
+ logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
+ return defaultWaitTime, nil
+ case *fedsenderapi.FederationClientError:
+ if e.RetryAfter > 0 {
+ return e.RetryAfter, err
+ } else if e.Blacklisted {
+ return time.Hour * 8, err
+ }
+ case net.Error:
+ // Use the default waitTime, if it's a timeout.
+ // It probably doesn't make sense to try further users.
+ if !e.Timeout() {
+ logger.WithError(e).Debug("GetUserDevices returned net.Error")
+ return time.Minute * 10, err
+ }
+ case gomatrix.HTTPError:
+ // The remote server returned an error, give it some time to recover.
+ // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
+ if e.Code >= 300 {
+ logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
+ return hourWaitTime, err
+ }
+ default:
+ // Something else failed
+ logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
+ return time.Minute * 10, err
+ }
+ }
+ if res.UserID != userID {
+ logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
+ return defaultWaitTime, nil
+ }
+ if res.MasterKey != nil || res.SelfSigningKey != nil {
+ uploadReq := &api.PerformUploadDeviceKeysRequest{
+ UserID: userID,
+ }
+ uploadRes := &api.PerformUploadDeviceKeysResponse{}
+ if res.MasterKey != nil {
+ if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
+ uploadReq.MasterKey = *res.MasterKey
+ }
+ }
+ if res.SelfSigningKey != nil {
+ if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
+ uploadReq.SelfSigningKey = *res.SelfSigningKey
+ }
+ }
+ _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
+ }
+ err = u.updateDeviceList(&res)
+ if err != nil {
+ logger.WithError(err).Error("Fetched device list but failed to store/emit it")
+ return defaultWaitTime, err
+ }
+ return defaultWaitTime, nil
+}
+
+func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
+ ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
+ keys := make([]api.DeviceMessage, len(res.Devices))
+ existingKeys := make([]api.DeviceMessage, len(res.Devices))
+ for i, device := range res.Devices {
+ keyJSON, err := json.Marshal(device.Keys)
+ if err != nil {
+ util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
+ continue
+ }
+ keys[i] = api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ StreamID: res.StreamID,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: device.DeviceID,
+ DisplayName: device.DisplayName,
+ UserID: res.UserID,
+ KeyJSON: keyJSON,
+ },
+ }
+ existingKeys[i] = api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: res.UserID,
+ DeviceID: device.DeviceID,
+ },
+ }
+ }
+ // fetch what keys we had already and only emit changes
+ if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
+ // non-fatal, log and continue
+ util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf(
+ "failed to query device keys json for calculating diffs",
+ )
+ }
+
+ err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
+ if err != nil {
+ return fmt.Errorf("failed to store remote device keys: %w", err)
+ }
+ err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
+ if err != nil {
+ return fmt.Errorf("failed to mark device list as fresh: %w", err)
+ }
+ err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
+ if err != nil {
+ return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go
new file mode 100644
index 00000000..7d357c95
--- /dev/null
+++ b/userapi/internal/device_list_update_default.go
@@ -0,0 +1,22 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !vw
+
+package internal
+
+import "time"
+
+const defaultWaitTime = time.Minute
+const hourWaitTime = time.Hour
diff --git a/userapi/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go
new file mode 100644
index 00000000..1c60d2eb
--- /dev/null
+++ b/userapi/internal/device_list_update_sytest.go
@@ -0,0 +1,25 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vw
+
+package internal
+
+import "time"
+
+// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
+// results in a one-hour wait time from a previous device so the test times out. This is fine for
+// production, but makes an otherwise passing test fail.
+const defaultWaitTime = time.Second
+const hourWaitTime = time.Second
diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go
new file mode 100644
index 00000000..868fc9be
--- /dev/null
+++ b/userapi/internal/device_list_update_test.go
@@ -0,0 +1,431 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "crypto/ed25519"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+
+ roomserver "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
+)
+
+var (
+ ctx = context.Background()
+)
+
+type mockKeyChangeProducer struct {
+ events []api.DeviceMessage
+}
+
+func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
+ p.events = append(p.events, keys...)
+ return nil
+}
+
+type mockDeviceListUpdaterDatabase struct {
+ staleUsers map[string]bool
+ prevIDsExist func(string, []int64) bool
+ storedKeys []api.DeviceMessage
+ mu sync.Mutex // protect staleUsers
+}
+
+func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
+ return nil
+}
+
+// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+// If no domains are given, all user IDs with stale device lists are returned.
+func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ var result []string
+ for userID, isStale := range d.staleUsers {
+ if !isStale {
+ continue
+ }
+ _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return nil, err
+ }
+ if len(domains) == 0 {
+ result = append(result, userID)
+ continue
+ }
+ for _, d := range domains {
+ if remoteServer == d {
+ result = append(result, userID)
+ break
+ }
+ }
+ }
+ return result, nil
+}
+
+// MarkDeviceListStale sets the stale bit for this user to isStale.
+func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.staleUsers[userID] = isStale
+ return nil
+}
+
+func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.staleUsers[userID]
+}
+
+// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+// for this (user, device). Does not modify the stream ID for keys.
+func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
+ d.storedKeys = append(d.storedKeys, keys...)
+ return nil
+}
+
+// PrevIDsExists returns true if all prev IDs exist for this user.
+func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
+ return d.prevIDsExist(userID, prevIDs), nil
+}
+
+func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ return nil
+}
+
+type mockDeviceListUpdaterAPI struct {
+}
+
+func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
+ return nil
+}
+
+type roundTripper struct {
+ fn func(*http.Request) (*http.Response, error)
+}
+
+func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ return t.fn(req)
+}
+
+func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
+ _, pkey, _ := ed25519.GenerateKey(nil)
+ fedClient := gomatrixserverlib.NewFederationClient(
+ []*gomatrixserverlib.SigningIdentity{
+ {
+ ServerName: gomatrixserverlib.ServerName("example.test"),
+ KeyID: gomatrixserverlib.KeyID("ed25519:test"),
+ PrivateKey: pkey,
+ },
+ },
+ )
+ fedClient.Client = *gomatrixserverlib.NewClient(
+ gomatrixserverlib.WithTransport(&roundTripper{tripper}),
+ )
+ return fedClient
+}
+
+// Test that the device keys get persisted and emitted if we have the previous IDs.
+func TestUpdateHavePrevID(t *testing.T) {
+ db := &mockDeviceListUpdaterDatabase{
+ staleUsers: make(map[string]bool),
+ prevIDsExist: func(string, []int64) bool {
+ return true
+ },
+ }
+ ap := &mockDeviceListUpdaterAPI{}
+ producer := &mockKeyChangeProducer{}
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
+ event := gomatrixserverlib.DeviceListUpdateEvent{
+ DeviceDisplayName: "Foo Bar",
+ Deleted: false,
+ DeviceID: "FOO",
+ Keys: []byte(`{"key":"value"}`),
+ PrevID: []int64{0},
+ StreamID: 1,
+ UserID: "@alice:localhost",
+ }
+ err := updater.Update(ctx, event)
+ if err != nil {
+ t.Fatalf("Update returned an error: %s", err)
+ }
+ want := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ StreamID: event.StreamID,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: event.DeviceID,
+ DisplayName: event.DeviceDisplayName,
+ KeyJSON: event.Keys,
+ UserID: event.UserID,
+ },
+ }
+ if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
+ t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
+ }
+ if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
+ t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
+ }
+ if db.isStale(event.UserID) {
+ t.Errorf("%s incorrectly marked as stale", event.UserID)
+ }
+}
+
+// Test that device keys are fetched from the remote server if we are missing prev IDs
+// and that the user's devices are marked as stale until it succeeds.
+func TestUpdateNoPrevID(t *testing.T) {
+ db := &mockDeviceListUpdaterDatabase{
+ staleUsers: make(map[string]bool),
+ prevIDsExist: func(string, []int64) bool {
+ return false
+ },
+ }
+ ap := &mockDeviceListUpdaterAPI{}
+ producer := &mockKeyChangeProducer{}
+ remoteUserID := "@alice:example.somewhere"
+ var wg sync.WaitGroup
+ wg.Add(1)
+ keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
+ fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
+ defer wg.Done()
+ if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
+ return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
+ }
+ return &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader(`
+ {
+ "user_id": "` + remoteUserID + `",
+ "stream_id": 5,
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": ` + keyJSON + `,
+ "device_display_name": "Mobile Phone"
+ }
+ ]
+ }
+ `)),
+ }, nil
+ })
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
+ if err := updater.Start(); err != nil {
+ t.Fatalf("failed to start updater: %s", err)
+ }
+ event := gomatrixserverlib.DeviceListUpdateEvent{
+ DeviceDisplayName: "Mobile Phone",
+ Deleted: false,
+ DeviceID: "another_device_id",
+ Keys: []byte(`{"key":"value"}`),
+ PrevID: []int64{3},
+ StreamID: 4,
+ UserID: remoteUserID,
+ }
+ err := updater.Update(ctx, event)
+
+ if err != nil {
+ t.Fatalf("Update returned an error: %s", err)
+ }
+ t.Log("waiting for /users/devices to be called...")
+ wg.Wait()
+ // wait a bit for db to be updated...
+ time.Sleep(100 * time.Millisecond)
+ want := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ StreamID: 5,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "JLAFKJWSCS",
+ DisplayName: "Mobile Phone",
+ UserID: remoteUserID,
+ KeyJSON: []byte(keyJSON),
+ },
+ }
+ // Now we should have a fresh list and the keys and emitted something
+ if db.isStale(event.UserID) {
+ t.Errorf("%s still marked as stale", event.UserID)
+ }
+ if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
+ t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
+ t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
+ }
+ if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
+ t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
+ }
+
+}
+
+// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
+// update is still ongoing.
+func TestDebounce(t *testing.T) {
+ t.Skipf("panic on closed channel on GHA")
+ db := &mockDeviceListUpdaterDatabase{
+ staleUsers: make(map[string]bool),
+ prevIDsExist: func(string, []int64) bool {
+ return true
+ },
+ }
+ ap := &mockDeviceListUpdaterAPI{}
+ producer := &mockKeyChangeProducer{}
+ fedCh := make(chan *http.Response, 1)
+ srv := gomatrixserverlib.ServerName("example.com")
+ userID := "@alice:example.com"
+ keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
+ incomingFedReq := make(chan struct{})
+ fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
+ if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
+ return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
+ }
+ close(incomingFedReq)
+ return <-fedCh, nil
+ })
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
+ if err := updater.Start(); err != nil {
+ t.Fatalf("failed to start updater: %s", err)
+ }
+
+ // hit this 5 times
+ var wg sync.WaitGroup
+ wg.Add(5)
+ for i := 0; i < 5; i++ {
+ go func() {
+ defer wg.Done()
+ if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
+ t.Errorf("ManualUpdate: %s", err)
+ }
+ }()
+ }
+
+ // wait until the updater hits federation
+ select {
+ case <-incomingFedReq:
+ case <-time.After(time.Second):
+ t.Fatalf("timed out waiting for updater to hit federation")
+ }
+
+ // user should be marked as stale
+ if !db.isStale(userID) {
+ t.Errorf("user %s not marked as stale", userID)
+ }
+ // now send the response over federation
+ fedCh <- &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader(`
+ {
+ "user_id": "` + userID + `",
+ "stream_id": 5,
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": ` + keyJSON + `,
+ "device_display_name": "Mobile Phone"
+ }
+ ]
+ }
+ `)),
+ }
+ close(fedCh)
+ // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
+ // and should panic with read on a closed channel
+ wg.Wait()
+
+ // user is no longer stale now
+ if db.isStale(userID) {
+ t.Errorf("user %s is marked as stale", userID)
+ }
+}
+
+func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ t.Helper()
+
+ base, _, _ := testrig.Base(nil)
+ connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return db, clearDB
+}
+
+type mockKeyserverRoomserverAPI struct {
+ leftUsers []string
+}
+
+func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
+ res.LeftUsers = m.leftUsers
+ return nil
+}
+
+func TestDeviceListUpdater_CleanUp(t *testing.T) {
+ processCtx := process.NewProcessContext()
+
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ // Bob is not joined to any of our rooms
+ rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clearDB := mustCreateKeyserverDB(t, dbType)
+ defer clearDB()
+
+ // This should not get deleted
+ if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
+ t.Error(err)
+ }
+
+ // this one should get deleted
+ if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
+ t.Error(err)
+ }
+
+ updater := NewDeviceListUpdater(processCtx, db, nil,
+ nil, nil,
+ 0, rsAPI, "test")
+ if err := updater.CleanUp(); err != nil {
+ t.Error(err)
+ }
+
+ // check that we still have Alice in our stale list
+ staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Error(err)
+ }
+
+ // There should only be Alice
+ wantCount := 1
+ if count := len(staleUsers); count != wantCount {
+ t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
+ }
+
+ if staleUsers[0] != alice.ID {
+ t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
+ }
+ })
+}
diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go
new file mode 100644
index 00000000..be816fe5
--- /dev/null
+++ b/userapi/internal/key_api.go
@@ -0,0 +1,798 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+
+ "github.com/matrix-org/dendrite/userapi/api"
+)
+
+func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
+ userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return nil
+ }
+ res.Offset = latest
+ res.UserIDs = userIDs
+ return nil
+}
+
+func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
+ res.KeyErrors = make(map[string]map[string]*api.KeyError)
+ if len(req.DeviceKeys) > 0 {
+ a.uploadLocalDeviceKeys(ctx, req, res)
+ }
+ if len(req.OneTimeKeys) > 0 {
+ a.uploadOneTimeKeys(ctx, req, res)
+ }
+ otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ return err
+ }
+ res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
+ return nil
+}
+
+func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
+ res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
+ res.Failures = make(map[string]interface{})
+ // wrap request map in a top-level by-domain map
+ domainToDeviceKeys := make(map[string]map[string]map[string]string)
+ for userID, val := range req.OneTimeKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ nested, ok := domainToDeviceKeys[string(serverName)]
+ if !ok {
+ nested = make(map[string]map[string]string)
+ }
+ nested[userID] = val
+ domainToDeviceKeys[string(serverName)] = nested
+ }
+ for domain, local := range domainToDeviceKeys {
+ if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ // claim local keys
+ keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
+ }
+ }
+ util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
+ for _, key := range keys {
+ _, ok := res.OneTimeKeys[key.UserID]
+ if !ok {
+ res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
+ }
+ _, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
+ if !ok {
+ res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
+ }
+ for keyID, keyJSON := range key.KeyJSON {
+ res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
+ }
+ }
+ delete(domainToDeviceKeys, domain)
+ }
+ if len(domainToDeviceKeys) > 0 {
+ a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) claimRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
+) {
+ var wg sync.WaitGroup // Wait for fan-out goroutines to finish
+ var mu sync.Mutex // Protects the response struct
+ var claimed int // Number of keys claimed in total
+ var failures int // Number of servers we failed to ask
+
+ util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
+ wg.Add(len(domainToDeviceKeys))
+
+ for d, k := range domainToDeviceKeys {
+ go func(domain string, keysToClaim map[string]map[string]string) {
+ fedCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ defer wg.Done()
+
+ claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
+
+ mu.Lock()
+ defer mu.Unlock()
+
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
+ res.Failures[domain] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ failures++
+ return
+ }
+
+ for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
+ res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
+ for deviceID, keys := range deviceIDToKeys {
+ res.OneTimeKeys[userID][deviceID] = keys
+ claimed += len(keys)
+ }
+ }
+ }(d, k)
+ }
+
+ wg.Wait()
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "num_keys": claimed,
+ "num_failures": failures,
+ }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
+}
+
+func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
+ if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to delete device keys: %s", err),
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
+ count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
+ }
+ return nil
+ }
+ res.Count = *count
+ return nil
+}
+
+func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
+ msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
+ }
+ return nil
+ }
+ maxStreamID := int64(0)
+ // remove deleted devices
+ var result []api.DeviceMessage
+ for _, m := range msgs {
+ if m.StreamID > maxStreamID {
+ maxStreamID = m.StreamID
+ }
+ if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
+ continue
+ }
+ result = append(result, m)
+ }
+ res.Devices = result
+ res.StreamID = maxStreamID
+ return nil
+}
+
+// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
+// in our database.
+func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
+ knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
+ if err != nil {
+ return err
+ }
+ if len(knownDevices) == 0 {
+ return nil // fmt.Errorf("unknown user %s", req.UserID)
+ }
+
+ for i := range knownDevices {
+ if knownDevices[i].DeviceID == req.DeviceID {
+ return nil // we already know about this device
+ }
+ }
+
+ return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
+}
+
+// nolint:gocyclo
+func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
+ var respMu sync.Mutex
+ res.DeviceKeys = make(map[string]map[string]json.RawMessage)
+ res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.Failures = make(map[string]interface{})
+
+ // make a map from domain to device keys
+ domainToDeviceKeys := make(map[string]map[string][]string)
+ domainToCrossSigningKeys := make(map[string]map[string]struct{})
+ for userID, deviceIDs := range req.UserToDevices {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ domain := string(serverName)
+ // query local devices
+ if a.Config.Matrix.IsLocalServerName(serverName) {
+ deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query local device keys: %s", err),
+ }
+ return nil
+ }
+
+ // pull out display names after we have the keys so we handle wildcards correctly
+ var dids []string
+ for _, dk := range deviceKeys {
+ dids = append(dids, dk.DeviceID)
+ }
+ var queryRes api.QueryDeviceInfosResponse
+ err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
+ DeviceIDs: dids,
+ }, &queryRes)
+ if err != nil {
+ util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
+ }
+
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+ for _, dk := range deviceKeys {
+ if len(dk.KeyJSON) == 0 {
+ continue // don't include blank keys
+ }
+ // inject display name if known (either locally or remotely)
+ displayName := dk.DisplayName
+ if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
+ displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
+ }
+ dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{displayName})
+ res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
+ }
+ } else {
+ domainToDeviceKeys[domain] = make(map[string][]string)
+ domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
+ }
+ // work out if our cross-signing request for this user was
+ // satisfied, if not add them to the list of things to fetch
+ if _, ok := res.MasterKeys[userID]; !ok {
+ if _, ok := domainToCrossSigningKeys[domain]; !ok {
+ domainToCrossSigningKeys[domain] = make(map[string]struct{})
+ }
+ domainToCrossSigningKeys[domain][userID] = struct{}{}
+ }
+ if _, ok := res.SelfSigningKeys[userID]; !ok {
+ if _, ok := domainToCrossSigningKeys[domain]; !ok {
+ domainToCrossSigningKeys[domain] = make(map[string]struct{})
+ }
+ domainToCrossSigningKeys[domain][userID] = struct{}{}
+ }
+ }
+
+ // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
+ domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
+ if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
+ // perform key queries for remote devices
+ a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
+ }
+
+ // Now that we've done the potentially expensive work of asking the federation,
+ // try filling the cross-signing keys from the database that we know about.
+ a.crossSigningKeysFromDatabase(ctx, req, res)
+
+ // Finally, append signatures that we know about
+ // TODO: This is horrible because we need to round-trip the signature from
+ // JSON, add the signatures and marshal it again, for some reason?
+
+ for targetUserID, masterKey := range res.MasterKeys {
+ if masterKey.Signatures == nil {
+ masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for targetKeyID := range masterKey.Keys {
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
+ if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
+ continue
+ }
+ if len(sigMap) == 0 {
+ continue
+ }
+ for sourceUserID, forSourceUser := range sigMap {
+ for sourceKeyID, sourceSig := range forSourceUser {
+ if _, ok := masterKey.Signatures[sourceUserID]; !ok {
+ masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
+ }
+ }
+ }
+ }
+
+ for targetUserID, forUserID := range res.DeviceKeys {
+ for targetKeyID, key := range forUserID {
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
+ if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
+ continue
+ }
+ if len(sigMap) == 0 {
+ continue
+ }
+ var deviceKey gomatrixserverlib.DeviceKeys
+ if err = json.Unmarshal(key, &deviceKey); err != nil {
+ continue
+ }
+ if deviceKey.Signatures == nil {
+ deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for sourceUserID, forSourceUser := range sigMap {
+ for sourceKeyID, sourceSig := range forSourceUser {
+ if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
+ deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
+ }
+ }
+ if js, err := json.Marshal(deviceKey); err == nil {
+ res.DeviceKeys[targetUserID][targetKeyID] = js
+ }
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) remoteKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
+) map[string]map[string][]string {
+ fetchRemote := make(map[string]map[string][]string)
+ for domain, userToDeviceMap := range domainToDeviceKeys {
+ for userID, deviceIDs := range userToDeviceMap {
+ // we can't safely return keys from the db when all devices are requested as we don't
+ // know if one has just been added.
+ if len(deviceIDs) > 0 {
+ err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
+ if err == nil {
+ continue
+ }
+ util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
+ }
+ // fetch device lists from remote
+ if _, ok := fetchRemote[domain]; !ok {
+ fetchRemote[domain] = make(map[string][]string)
+ }
+ fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
+
+ }
+ }
+ return fetchRemote
+}
+
+func (a *UserInternalAPI) queryRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
+ domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
+) {
+ resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
+ // allows us to wait until all federation servers have been poked
+ var wg sync.WaitGroup
+ // mutex for writing directly to res (e.g failures)
+ var respMu sync.Mutex
+
+ domains := map[string]struct{}{}
+ for domain := range domainToDeviceKeys {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ domains[domain] = struct{}{}
+ }
+ for domain := range domainToCrossSigningKeys {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ domains[domain] = struct{}{}
+ }
+ wg.Add(len(domains))
+
+ // fan out
+ for domain := range domains {
+ go a.queryRemoteKeysOnServer(
+ ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
+ &wg, &respMu, timeout, resultCh, res,
+ )
+ }
+
+ // Close the result channel when the goroutines have quit so the for .. range exits
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ processResult := func(result *gomatrixserverlib.RespQueryKeys) {
+ respMu.Lock()
+ defer respMu.Unlock()
+ for userID, nest := range result.DeviceKeys {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ for deviceID, deviceKey := range nest {
+ keyJSON, err := json.Marshal(deviceKey)
+ if err != nil {
+ continue
+ }
+ res.DeviceKeys[userID][deviceID] = keyJSON
+ }
+ }
+
+ for userID, body := range result.MasterKeys {
+ res.MasterKeys[userID] = body
+ }
+
+ for userID, body := range result.SelfSigningKeys {
+ res.SelfSigningKeys[userID] = body
+ }
+
+ // TODO: do we want to persist these somewhere now
+ // that we have fetched them?
+ }
+
+ for result := range resultCh {
+ processResult(result)
+ }
+}
+
+func (a *UserInternalAPI) queryRemoteKeysOnServer(
+ ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
+ wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
+ res *api.QueryKeysResponse,
+) {
+ defer wg.Done()
+ fedCtx := ctx
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ fedCtx, cancel = context.WithTimeout(ctx, timeout)
+ defer cancel()
+ }
+ // for users who we do not have any knowledge about, try to start doing device list updates for them
+ // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
+ // lack a stream ID.
+ userIDsForAllDevices := map[string]struct{}{}
+ for userID, deviceIDs := range devKeys {
+ if len(deviceIDs) == 0 {
+ userIDsForAllDevices[userID] = struct{}{}
+ }
+ }
+ // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
+ // a device list update, so we'll populate those back into the /keys/query list if not
+ for userID := range crossSigningKeys {
+ if devKeys == nil {
+ devKeys = map[string][]string{}
+ }
+ if _, ok := userIDsForAllDevices[userID]; !ok {
+ devKeys[userID] = []string{}
+ }
+ }
+ for userID := range userIDsForAllDevices {
+ err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
+ // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
+ err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ }
+ if len(devKeys) == 0 {
+ return
+ }
+ queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
+ if err == nil {
+ resultCh <- &queryKeysResp
+ return
+ }
+ respMu.Lock()
+ res.Failures[serverName] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ respMu.Unlock()
+
+ // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
+ // is down, better to return something than nothing at all. Clients can know about the failure by
+ // inspecting the failures map though so they can know it's a cached response.
+ for userID, dkeys := range devKeys {
+ // drop the error as it's already a failure at this point
+ _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
+ }
+
+ // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
+ respMu.Lock()
+ if len(res.DeviceKeys) > 0 {
+ delete(res.Failures, serverName)
+ }
+ respMu.Unlock()
+}
+
+func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
+) error {
+ keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ // if we can't query the db or there are fewer keys than requested, fetch from remote.
+ if err != nil {
+ return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
+ }
+ if len(keys) < len(deviceIDs) {
+ return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
+ }
+ if len(deviceIDs) == 0 && len(keys) == 0 {
+ return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
+ }
+ respMu.Lock()
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+ respMu.Unlock()
+
+ for _, key := range keys {
+ if len(key.KeyJSON) == 0 {
+ continue // ignore deleted keys
+ }
+ // inject the display name
+ key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{key.DisplayName})
+ respMu.Lock()
+ res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ respMu.Unlock()
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ // get a list of devices from the user API that actually exist, as
+ // we won't store keys for devices that don't exist
+ uapidevices := &api.QueryDevicesResponse{}
+ if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return
+ }
+ if !uapidevices.UserExists {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("user %q does not exist", req.UserID),
+ }
+ return
+ }
+ existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
+ for _, key := range uapidevices.Devices {
+ existingDeviceMap[key.ID] = struct{}{}
+ }
+
+ // Get all of the user existing device keys so we can check for changes.
+ existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
+ }
+ return
+ }
+
+ // Work out whether we have device keys in the keyserver for devices that
+ // no longer exist in the user API. This is mostly an exercise to ensure
+ // that we keep some integrity between the two.
+ var toClean []gomatrixserverlib.KeyID
+ for _, k := range existingKeys {
+ if _, ok := existingDeviceMap[k.DeviceID]; !ok {
+ toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
+ }
+ }
+
+ if len(toClean) > 0 {
+ if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
+ logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
+ } else {
+ logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
+ }
+ }
+
+ var keysToStore []api.DeviceMessage
+
+ if req.OnlyDisplayNameUpdates {
+ for _, existingKey := range existingKeys {
+ for _, newKey := range req.DeviceKeys {
+ switch {
+ case existingKey.UserID != newKey.UserID:
+ continue
+ case existingKey.DeviceID != newKey.DeviceID:
+ continue
+ case existingKey.DisplayName != newKey.DisplayName:
+ existingKey.DisplayName = newKey.DisplayName
+ }
+ }
+ keysToStore = append(keysToStore, existingKey)
+ }
+ } else {
+ // assert that the user ID / device ID are not lying for each key
+ for _, key := range req.DeviceKeys {
+ var serverName gomatrixserverlib.ServerName
+ _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ continue // ignore remote users
+ }
+ if len(key.KeyJSON) == 0 {
+ keysToStore = append(keysToStore, key.WithStreamID(0))
+ continue // deleted keys don't need sanity checking
+ }
+ // check that the device in question actually exists in the user
+ // API before we try and store a key for it
+ if _, ok := existingDeviceMap[key.DeviceID]; !ok {
+ continue
+ }
+ gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
+ gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
+ if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
+ keysToStore = append(keysToStore, key.WithStreamID(0))
+ continue
+ }
+
+ res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf(
+ "user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
+ gotUserID, key.UserID, gotDeviceID, key.DeviceID,
+ ),
+ })
+ }
+ }
+
+ // store the device keys and emit changes
+ err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
+ }
+ return
+ }
+ err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
+ if err != nil {
+ util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
+ }
+}
+
+func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ if req.UserID == "" {
+ res.Error = &api.KeyError{
+ Err: "user ID missing",
+ }
+ }
+ if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
+ counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
+ }
+ }
+ if counts != nil {
+ res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
+ }
+ return
+ }
+ for _, key := range req.OneTimeKeys {
+ // grab existing keys based on (user/device/algorithm/key ID)
+ keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
+ i := 0
+ for keyIDWithAlgo := range key.KeyJSON {
+ keyIDsWithAlgorithms[i] = keyIDWithAlgo
+ i++
+ }
+ existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
+ if err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: "failed to query existing one-time keys: " + err.Error(),
+ })
+ continue
+ }
+ for keyIDWithAlgo := range existingKeys {
+ // if keys exist and the JSON doesn't match, error out as the key already exists
+ if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
+ })
+ continue
+ }
+ }
+ // store one-time keys
+ counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
+ if err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
+ })
+ continue
+ }
+ // collect counts
+ res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
+ }
+
+}
+
+func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
+ // if we only want to update the display names, we can skip the checks below
+ if onlyUpdateDisplayName {
+ return producer.ProduceKeyChanges(new)
+ }
+ // find keys in new that are not in existing
+ var keysAdded []api.DeviceMessage
+ for _, newKey := range new {
+ exists := false
+ for _, existingKey := range existing {
+ // Do not treat the absence of keys as equal, or else we will not emit key changes
+ // when users delete devices which never had a key to begin with as both KeyJSONs are nil.
+ if existingKey.DeviceKeysEqual(&newKey) {
+ exists = true
+ break
+ }
+ }
+ if !exists {
+ keysAdded = append(keysAdded, newKey)
+ }
+ }
+ return producer.ProduceKeyChanges(keysAdded)
+}
diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go
new file mode 100644
index 00000000..fc7e7e0d
--- /dev/null
+++ b/userapi/internal/key_api_test.go
@@ -0,0 +1,161 @@
+package internal_test
+
+import (
+ "context"
+ "reflect"
+ "testing"
+
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/internal"
+ "github.com/matrix-org/dendrite/userapi/storage"
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ base, _, _ := testrig.Base(nil)
+ db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create new user db: %v", err)
+ }
+ return db, func() {
+ base.Close()
+ close()
+ }
+}
+
+func Test_QueryDeviceMessages(t *testing.T) {
+ alice := test.NewUser(t)
+ type args struct {
+ req *api.QueryDeviceMessagesRequest
+ res *api.QueryDeviceMessagesResponse
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ want *api.QueryDeviceMessagesResponse
+ }{
+ {
+ name: "no existing keys",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: "@doesNotExist:localhost",
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{},
+ },
+ {
+ name: "existing user returns devices",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: alice.ID,
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{
+ StreamID: 6,
+ Devices: []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ },
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ DisplayName: "second device",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ }, // streamID 6
+ },
+ },
+ },
+ },
+ }
+
+ deviceMessages := []api.DeviceMessage{
+ { // not the user we're looking for
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ UserID: "@doesNotExist:localhost",
+ },
+ // streamID 1 for this user
+ },
+ { // empty keyJSON will be ignored
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ }, // streamID 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("abc"),
+ }, // streamID 2
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("def"),
+ }, // streamID 3
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte(""),
+ }, // streamID 4
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ }, // streamID 5
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ DisplayName: "second device",
+ }, // streamID 6
+ },
+ }
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, closeDB := mustCreateDatabase(t, dbType)
+ defer closeDB()
+ if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
+ t.Fatalf("failed to store local devicesKeys")
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &internal.UserInternalAPI{
+ KeyDatabase: db,
+ }
+ if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
+ t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ got := tt.args.res
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
+ }
+ })
+ }
+ })
+}
diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go
index 0bb480da..1cbd9719 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/user_api.go
@@ -23,6 +23,7 @@ import (
"strconv"
"time"
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -32,7 +33,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
synctypes "github.com/matrix-org/dendrite/syncapi/types"
@@ -44,17 +44,19 @@ import (
)
type UserInternalAPI struct {
- DB storage.Database
- SyncProducer *producers.SyncAPI
- Config *config.UserAPI
+ DB storage.UserDatabase
+ KeyDatabase storage.KeyDatabase
+ SyncProducer *producers.SyncAPI
+ KeyChangeProducer *producers.KeyChange
+ Config *config.UserAPI
DisableTLSValidation bool
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
- KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
- Cfg *config.UserAPI
+ FedClient fedsenderapi.KeyserverFederationAPI
+ Updater *DeviceListUpdater
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
- postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
+ postRegisterJoinRooms(a.Config, acc, a.RSAPI)
res.AccountCreated = true
res.Account = acc
@@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
return err
}
// Ask the keyserver to delete device keys and signatures for those devices
- deleteReq := &keyapi.PerformDeleteKeysRequest{
+ deleteReq := &api.PerformDeleteKeysRequest{
UserID: req.UserID,
}
for _, keyID := range req.DeviceIDs {
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
- deleteRes := &keyapi.PerformDeleteKeysResponse{}
- if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
+ deleteRes := &api.PerformDeleteKeysResponse{}
+ if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil {
@@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
- deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
+ deviceKeys := make([]api.DeviceKeys, len(deviceIDs))
for i, did := range deviceIDs {
- deviceKeys[i] = keyapi.DeviceKeys{
+ deviceKeys[i] = api.DeviceKeys{
UserID: userID,
DeviceID: did,
KeyJSON: nil,
}
}
- var uploadRes keyapi.PerformUploadKeysResponse
- if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ var uploadRes api.PerformUploadKeysResponse
+ if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
}, &uploadRes); err != nil {
@@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
}
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
- var uploadRes keyapi.PerformUploadKeysResponse
- if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
+ var uploadRes api.PerformUploadKeysResponse
+ if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
- DeviceKeys: []keyapi.DeviceKeys{
+ DeviceKeys: []api.DeviceKeys{
{
DeviceID: dev.ID,
DisplayName: *req.DisplayName,
diff --git a/userapi/producers/keychange.go b/userapi/producers/keychange.go
new file mode 100644
index 00000000..da6cea31
--- /dev/null
+++ b/userapi/producers/keychange.go
@@ -0,0 +1,107 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package producers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
+)
+
+// KeyChange produces key change events for the sync API and federation sender to consume
+type KeyChange struct {
+ Topic string
+ JetStream JetStreamPublisher
+ DB storage.KeyChangeDatabase
+}
+
+// ProduceKeyChanges creates new change events for each key
+func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
+ userToDeviceCount := make(map[string]int)
+ for _, key := range keys {
+ id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
+ if err != nil {
+ return err
+ }
+ key.DeviceChangeID = id
+ value, err := json.Marshal(key)
+ if err != nil {
+ return err
+ }
+
+ m := &nats.Msg{
+ Subject: p.Topic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, key.UserID)
+ m.Data = value
+
+ _, err = p.JetStream.PublishMsg(m)
+ if err != nil {
+ return err
+ }
+
+ userToDeviceCount[key.UserID]++
+ }
+ for userID, count := range userToDeviceCount {
+ logrus.WithFields(logrus.Fields{
+ "user_id": userID,
+ "num_key_changes": count,
+ }).Tracef("Produced to key change topic '%s'", p.Topic)
+ }
+ return nil
+}
+
+func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error {
+ output := &api.DeviceMessage{
+ Type: api.TypeCrossSigningUpdate,
+ OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{
+ CrossSigningKeyUpdate: key,
+ },
+ }
+
+ id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
+ if err != nil {
+ return err
+ }
+ output.DeviceChangeID = id
+
+ value, err := json.Marshal(output)
+ if err != nil {
+ return err
+ }
+
+ m := &nats.Msg{
+ Subject: p.Topic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, key.UserID)
+ m.Data = value
+
+ _, err = p.JetStream.PublishMsg(m)
+ if err != nil {
+ return err
+ }
+
+ logrus.WithFields(logrus.Fields{
+ "user_id": key.UserID,
+ }).Tracef("Produced to cross-signing update topic '%s'", p.Topic)
+ return nil
+}
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
index 51eaa985..165de899 100644
--- a/userapi/producers/syncapi.go
+++ b/userapi/producers/syncapi.go
@@ -19,13 +19,13 @@ type JetStreamPublisher interface {
// SyncAPI produces messages for the Sync API server to consume.
type SyncAPI struct {
- db storage.Database
+ db storage.Notification
producer JetStreamPublisher
clientDataTopic string
notificationDataTopic string
}
-func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
+func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
return &SyncAPI{
db: db,
producer: js,
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index c22b7658..27837886 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -90,7 +90,7 @@ type KeyBackup interface {
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
- // determined by the loginTokenLifetime given to the Database constructor.
+ // determined by the loginTokenLifetime given to the UserDatabase constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
@@ -130,7 +130,7 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
-type Database interface {
+type UserDatabase interface {
Account
AccountData
Device
@@ -144,6 +144,78 @@ type Database interface {
ThreePID
}
+type KeyChangeDatabase interface {
+ // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
+ // `userID` is the the user who has changed their keys in some way.
+ StoreKeyChange(ctx context.Context, userID string) (int64, error)
+}
+
+type KeyDatabase interface {
+ KeyChangeDatabase
+ // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
+ // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
+ ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+
+ // StoreOneTimeKeys persists the given one-time keys.
+ StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+
+ // OneTimeKeysCount returns a count of all OTKs for this device.
+ OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device).
+ // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
+ // Returns an error if there was a problem storing the keys.
+ StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
+
+ // PrevIDsExists returns true if all prev IDs exist for this user.
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
+
+ // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
+ // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+
+ // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+ // cross-signing signatures relating to that device.
+ DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
+
+ // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
+ // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
+ ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
+
+ // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
+ // A to offset of types.OffsetNewest means no upper limit.
+ // Returns the offset of the latest key change.
+ KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+
+ // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+ // If no domains are given, all user IDs with stale device lists are returned.
+ StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+
+ // MarkDeviceListStale sets the stale bit for this user to isStale.
+ MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
+
+ CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
+ CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
+ CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
+
+ StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
+ StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+
+ DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+ ) error
+}
+
type Statistics interface {
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)
diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 2a4777d7..05716037 100644
--- a/userapi/storage/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
@@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
- _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
+ // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
+ // when passing the data to trigger "NOT NULL" constraint
+ var data *json.RawMessage
+ if len(content) > 0 {
+ data = &content
+ }
+ _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
return
}
diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go
new file mode 100644
index 00000000..c0ecbd30
--- /dev/null
+++ b/userapi/storage/postgres/cross_signing_keys_table.go
@@ -0,0 +1,102 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningKeysSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ key_type SMALLINT NOT NULL,
+ key_data TEXT NOT NULL,
+ PRIMARY KEY (user_id, key_type)
+);
+`
+
+const selectCrossSigningKeysForUserSQL = "" +
+ "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
+ " WHERE user_id = $1"
+
+const upsertCrossSigningKeysForUserSQL = "" +
+ "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
+ " VALUES($1, $2, $3)" +
+ " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
+
+type crossSigningKeysStatements struct {
+ db *sql.DB
+ selectCrossSigningKeysForUserStmt *sql.Stmt
+ upsertCrossSigningKeysForUserStmt *sql.Stmt
+}
+
+func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
+ s := &crossSigningKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
+ {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string,
+) (r types.CrossSigningKeyMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
+ r = types.CrossSigningKeyMap{}
+ for rows.Next() {
+ var keyTypeInt int16
+ var keyData gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
+ return nil, err
+ }
+ keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
+ if !ok {
+ return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
+ }
+ r[keyType] = keyData
+ }
+ return
+}
+
+func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
+) error {
+ keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
+ if !ok {
+ return fmt.Errorf("unknown key purpose %q", keyType)
+ }
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go
new file mode 100644
index 00000000..b0117145
--- /dev/null
+++ b/userapi/storage/postgres/cross_signing_sigs_table.go
@@ -0,0 +1,131 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningSigsSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+`
+
+const selectCrossSigningSigsForTargetSQL = "" +
+ "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
+ " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
+
+const upsertCrossSigningSigsForTargetSQL = "" +
+ "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
+ " VALUES($1, $2, $3, $4, $5)" +
+ " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
+
+const deleteCrossSigningSigsForTargetSQL = "" +
+ "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
+
+type crossSigningSigsStatements struct {
+ db *sql.DB
+ selectCrossSigningSigsForTargetStmt *sql.Stmt
+ upsertCrossSigningSigsForTargetStmt *sql.Stmt
+ deleteCrossSigningSigsForTargetStmt *sql.Stmt
+}
+
+func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
+ s := &crossSigningSigsStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningSigsSchema)
+ if err != nil {
+ return nil, err
+ }
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
+ {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
+ {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) (r types.CrossSigningSigMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
+ r = types.CrossSigningSigMap{}
+ for rows.Next() {
+ var userID string
+ var keyID gomatrixserverlib.KeyID
+ var signature gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&userID, &keyID, &signature); err != nil {
+ return nil, err
+ }
+ if _, ok := r[userID]; !ok {
+ r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ r[userID][keyID] = signature
+ }
+ return
+}
+
+func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
+
+func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
+ return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
new file mode 100644
index 00000000..0cfe9e79
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
@@ -0,0 +1,69 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ // start counting from the last max offset, else 0. We need to do a count(*) first to see if there
+ // even are entries in this table to know if we can query for log_offset. Without the count then
+ // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
+ // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
+ var count int
+ _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
+ if count > 0 {
+ var maxOffset int64
+ _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
+ if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
+ return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
+ }
+ }
+
+ _, err := tx.ExecContext(ctx, `
+ -- make the new table
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
+ DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ log_offset BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
new file mode 100644
index 00000000..1a3d4fee
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
@@ -0,0 +1,47 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
+ ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
+
+ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
+ ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
+
+ DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go
new file mode 100644
index 00000000..a9203857
--- /dev/null
+++ b/userapi/storage/postgres/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
+ -- This means we do not store an unbounded append-only log of device keys, which is not actually
+ -- required in the spec because in the event of a missed update the server fetches the entire
+ -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
+ stream_id BIGINT NOT NULL,
+ display_name TEXT,
+ -- Clobber based on tuple of user/device.
+ CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
+);
+`
+
+const upsertDeviceKeysSQL = "" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
+
+const selectDeviceKeysSQL = "" +
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectBatchDeviceKeysSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
+
+const countStreamIDsForUserSQL = "" +
+ "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
+
+const deleteDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
+type deviceKeysStatements struct {
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
+}
+
+func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+ s := &deviceKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(deviceKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ for i, key := range keys {
+ var keyJSONStr string
+ var streamID int64
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // this will be '' when there is no device
+ keys[i].Type = api.TypeDeviceKeyUpdate
+ keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt64
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int64
+ }
+ return
+}
+
+func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
+ // nullable if there are no results
+ var count sql.NullInt32
+ err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
+ if err != nil {
+ return 0, err
+ }
+ if count.Valid {
+ return int(count.Int32), nil
+ }
+ return 0, nil
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
+ deviceIDMap := make(map[string]bool)
+ for _, d := range deviceIDs {
+ deviceIDMap[d] = true
+ }
+ var result []api.DeviceMessage
+ var displayName sql.NullString
+ for rows.Next() {
+ dk := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
+ }
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
+ // include the key if we want all keys (no device) or it was asked
+ if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
+ result = append(result, dk)
+ }
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7481ac5b..88f8839c 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go
index 7b58f7ba..91a34c35 100644
--- a/userapi/storage/postgres/key_backup_table.go
+++ b/userapi/storage/postgres/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go
new file mode 100644
index 00000000..a0049414
--- /dev/null
+++ b/userapi/storage/postgres/key_changes_table.go
@@ -0,0 +1,127 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
+);
+`
+
+// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
+// have changed, hence we can just keep bumping the change ID for this user.
+const upsertKeyChangeSQL = "" +
+ "INSERT INTO keyserver_key_changes (user_id)" +
+ " VALUES ($1)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
+ " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
+ " RETURNING change_id"
+
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return s, err
+ }
+
+ if err = executeMigration(context.Background(), db); err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
+}
+
+func executeMigration(ctx context.Context, db *sql.DB) error {
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ migrationName := "keyserver: refactor key changes"
+
+ var cName string
+ err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
+ if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
+ return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
+ }
+ return nil
+ }
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: migrationName,
+ Up: deltas.UpRefactorKeyChanges,
+ })
+
+ return m.Up(ctx)
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
+ err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
+ return
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, fromOffset, toOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ latestOffset = fromOffset
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go
new file mode 100644
index 00000000..972a5914
--- /dev/null
+++ b/userapi/storage/postgres/one_time_keys_table.go
@@ -0,0 +1,194 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var oneTimeKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- Clobber based on 4-uple of user/device/key/algorithm.
+ CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
+`
+
+const upsertKeysSQL = "" +
+ "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
+ " DO UPDATE SET key_json = $6"
+
+const selectOneTimeKeysSQL = "" +
+ "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
+
+const selectKeysCountSQL = "" +
+ "SELECT algorithm, COUNT(key_id) FROM " +
+ " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
+ " x GROUP BY algorithm"
+
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
+const deleteOneTimeKeysSQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
+
+type oneTimeKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
+ deleteOneTimeKeysStmt *sql.Stmt
+}
+
+func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
+ s := &oneTimeKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(oneTimeKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
+
+ result := make(map[string]json.RawMessage)
+ var (
+ algorithmWithID string
+ keyJSONStr string
+ )
+ for rows.Next() {
+ if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
+ return nil, err
+ }
+ result[algorithmWithID] = json.RawMessage(keyJSONStr)
+ }
+ return result, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
+func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
+ now := time.Now().Unix()
+ counts := &api.OneTimeKeysCount{
+ DeviceID: keys.DeviceID,
+ UserID: keys.UserID,
+ KeyCount: make(map[string]int),
+ }
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+
+ return counts, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
+
+func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go
new file mode 100644
index 00000000..c823b58c
--- /dev/null
+++ b/userapi/storage/postgres/stale_device_lists.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
+
+type staleDeviceListsStatements struct {
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+ deleteStaleDeviceListsStmt *sql.Stmt
+}
+
+func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{}
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
+ }.Prepare(db)
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
+ _, err := stmt.ExecContext(ctx, pq.Array(userIDs))
+ return err
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index 92dc4808..673d123b 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewPostgresOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewPostgresDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewPostgresKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewPostgresStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewPostgresCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewPostgresCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index bf94f14d..d3272a03 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -59,6 +59,17 @@ type Database struct {
OpenIDTokenLifetimeMS int64
}
+type KeyDatabase struct {
+ OneTimeKeysTable tables.OneTimeKeys
+ DeviceKeysTable tables.DeviceKeys
+ KeyChangesTable tables.KeyChanges
+ StaleDeviceListsTable tables.StaleDeviceLists
+ CrossSigningKeysTable tables.CrossSigningKeys
+ CrossSigningSigsTable tables.CrossSigningSigs
+ DB *sql.DB
+ Writer sqlutil.Writer
+}
+
const (
// The length of generated device IDs
deviceIDByteLength = 6
@@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
+
+//
+
+func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
+}
+
+func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
+}
+
+func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
+}
+
+func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
+ count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
+ if err != nil {
+ return false, err
+ }
+ return count == len(prevIDs), nil
+}
+
+func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, userID := range clearUserIDs {
+ err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ // work out the latest stream IDs for each user
+ userIDToStreamID := make(map[string]int64)
+ for _, k := range keys {
+ userIDToStreamID[k.UserID] = 0
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID := range userIDToStreamID {
+ streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ userIDToStreamID[userID] = streamID
+ }
+ // set the stream IDs for each key
+ for i := range keys {
+ k := keys[i]
+ userIDToStreamID[k.UserID]++ // start stream from 1
+ k.StreamID = userIDToStreamID[k.UserID]
+ keys[i] = k
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
+}
+
+func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
+ var result []api.OneTimeKeys
+ err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID, deviceToAlgo := range userToDeviceToAlgorithm {
+ for deviceID, algo := range deviceToAlgo {
+ keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
+ if err != nil {
+ return err
+ }
+ if keyJSON != nil {
+ result = append(result, api.OneTimeKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: keyJSON,
+ })
+ }
+ }
+ }
+ return nil
+ })
+ return result, err
+}
+
+func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
+ err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
+ return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
+}
+
+// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+// If no domains are given, all user IDs with stale device lists are returned.
+func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
+}
+
+// MarkDeviceListStale sets the stale bit for this user to isStale.
+func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
+ return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
+ })
+}
+
+// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+// cross-signing signatures relating to that device.
+func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, deviceID := range deviceIDs {
+ if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
+ }
+ if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
+ }
+ if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
+ keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+ if err != nil {
+ return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
+ }
+ results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
+ for purpose, key := range keyMap {
+ keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
+ result := gomatrixserverlib.CrossSigningKey{
+ UserID: userID,
+ Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
+ Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
+ keyID: key,
+ },
+ }
+ sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
+ if err != nil {
+ continue
+ }
+ for sigUserID, forSigUserID := range sigMap {
+ if userID != sigUserID {
+ continue
+ }
+ if result.Signatures == nil {
+ result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ if _, ok := result.Signatures[sigUserID]; !ok {
+ result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for sigKeyID, sigBytes := range forSigUserID {
+ result.Signatures[sigUserID][sigKeyID] = sigBytes
+ }
+ }
+ results[purpose] = result
+ }
+ return results, nil
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
+ return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+}
+
+// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
+func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
+ return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
+}
+
+// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
+func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for keyType, keyData := range keyMap {
+ if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
+ return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
+func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
+ ctx context.Context,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
+ }
+ return nil
+ })
+}
+
+// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
+func (d *KeyDatabase) DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
+ })
+}
diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go
new file mode 100644
index 00000000..10721fcc
--- /dev/null
+++ b/userapi/storage/sqlite3/cross_signing_keys_table.go
@@ -0,0 +1,101 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningKeysSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ key_type INTEGER NOT NULL,
+ key_data TEXT NOT NULL,
+ PRIMARY KEY (user_id, key_type)
+);
+`
+
+const selectCrossSigningKeysForUserSQL = "" +
+ "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
+ " WHERE user_id = $1"
+
+const upsertCrossSigningKeysForUserSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
+ " VALUES($1, $2, $3)"
+
+type crossSigningKeysStatements struct {
+ db *sql.DB
+ selectCrossSigningKeysForUserStmt *sql.Stmt
+ upsertCrossSigningKeysForUserStmt *sql.Stmt
+}
+
+func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
+ s := &crossSigningKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
+ {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string,
+) (r types.CrossSigningKeyMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
+ r = types.CrossSigningKeyMap{}
+ for rows.Next() {
+ var keyTypeInt int16
+ var keyData gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
+ return nil, err
+ }
+ keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
+ if !ok {
+ return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
+ }
+ r[keyType] = keyData
+ }
+ return
+}
+
+func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
+) error {
+ keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
+ if !ok {
+ return fmt.Errorf("unknown key purpose %q", keyType)
+ }
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go
new file mode 100644
index 00000000..2be00c9c
--- /dev/null
+++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go
@@ -0,0 +1,129 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningSigsSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+`
+
+const selectCrossSigningSigsForTargetSQL = "" +
+ "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
+ " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
+
+const upsertCrossSigningSigsForTargetSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
+ " VALUES($1, $2, $3, $4, $5)"
+
+const deleteCrossSigningSigsForTargetSQL = "" +
+ "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
+
+type crossSigningSigsStatements struct {
+ db *sql.DB
+ selectCrossSigningSigsForTargetStmt *sql.Stmt
+ upsertCrossSigningSigsForTargetStmt *sql.Stmt
+ deleteCrossSigningSigsForTargetStmt *sql.Stmt
+}
+
+func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
+ s := &crossSigningSigsStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningSigsSchema)
+ if err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
+ {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
+ {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) (r types.CrossSigningSigMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
+ r = types.CrossSigningSigMap{}
+ for rows.Next() {
+ var userID string
+ var keyID gomatrixserverlib.KeyID
+ var signature gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&userID, &keyID, &signature); err != nil {
+ return nil, err
+ }
+ if _, ok := r[userID]; !ok {
+ r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ r[userID][keyID] = signature
+ }
+ return
+}
+
+func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
+
+func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
+ return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
new file mode 100644
index 00000000..cd0f19df
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
@@ -0,0 +1,66 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ // start counting from the last max offset, else 0.
+ var maxOffset int64
+ var userID string
+ _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
+
+ _, err := tx.ExecContext(ctx, `
+ -- make the new table
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (user_id)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ // to start counting from maxOffset, insert a row with that value
+ if userID != "" {
+ _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
+ return err
+ }
+ return nil
+}
+
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ offset BIGINT NOT NULL,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (partition, offset)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
new file mode 100644
index 00000000..d4e38dea
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
@@ -0,0 +1,71 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+ );
+
+ INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+ SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+ DROP TABLE keyserver_cross_signing_sigs;
+ ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
+ );
+
+ INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+ SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+ DROP TABLE keyserver_cross_signing_sigs;
+ ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+ DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go
new file mode 100644
index 00000000..15e69cc4
--- /dev/null
+++ b/userapi/storage/sqlite3/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
+ display_name TEXT,
+ -- Clobber based on tuple of user/device.
+ UNIQUE (user_id, device_id)
+);
+`
+
+const upsertDeviceKeysSQL = "" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (user_id, device_id)" +
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
+
+const selectDeviceKeysSQL = "" +
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectBatchDeviceKeysSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
+
+const countStreamIDsForUserSQL = "" +
+ "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
+
+const deleteDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
+type deviceKeysStatements struct {
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
+}
+
+func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+ s := &deviceKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(deviceKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ deviceIDMap := make(map[string]bool)
+ for _, d := range deviceIDs {
+ deviceIDMap[d] = true
+ }
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
+ var result []api.DeviceMessage
+ var displayName sql.NullString
+ for rows.Next() {
+ dk := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
+ }
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
+ // include the key if we want all keys (no device) or it was asked
+ if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
+ result = append(result, dk)
+ }
+ }
+ return result, rows.Err()
+}
+
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ for i, key := range keys {
+ var keyJSONStr string
+ var streamID int64
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // this will be '' when there is no device
+ keys[i].Type = api.TypeDeviceKeyUpdate
+ keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt64
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int64
+ }
+ return
+}
+
+func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
+ iStreamIDs := make([]interface{}, len(streamIDs)+1)
+ iStreamIDs[0] = userID
+ for i := range streamIDs {
+ iStreamIDs[i+1] = streamIDs[i]
+ }
+ query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
+ // nullable if there are no results
+ var count sql.NullInt64
+ err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
+ if err != nil {
+ return 0, err
+ }
+ if count.Valid {
+ return int(count.Int64), nil
+ }
+ return 0, nil
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 449e4549..65e17527 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) DeleteDevice(
@@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
if err != nil {
return err
}
+ defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+2)
params[0] = localpart
diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go
index 7883ffb1..ed274631 100644
--- a/userapi/storage/sqlite3/key_backup_table.go
+++ b/userapi/storage/sqlite3/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go
new file mode 100644
index 00000000..923bb57e
--- /dev/null
+++ b/userapi/storage/sqlite3/key_changes_table.go
@@ -0,0 +1,125 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (user_id)
+);
+`
+
+// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
+// have changed, hence we can just keep bumping the change ID for this user.
+const upsertKeyChangeSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
+ " VALUES ($1)" +
+ " RETURNING change_id"
+
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return s, err
+ }
+
+ if err = executeMigration(context.Background(), db); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
+}
+
+func executeMigration(ctx context.Context, db *sql.DB) error {
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ migrationName := "keyserver: refactor key changes"
+
+ var cName string
+ err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
+ if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
+ return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
+ }
+ return nil
+ }
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: migrationName,
+ Up: deltas.UpRefactorKeyChanges,
+ })
+ return m.Up(ctx)
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
+ err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
+ return
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, fromOffset, toOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ latestOffset = fromOffset
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go
new file mode 100644
index 00000000..a992d399
--- /dev/null
+++ b/userapi/storage/sqlite3/one_time_keys_table.go
@@ -0,0 +1,208 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var oneTimeKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- Clobber based on 4-uple of user/device/key/algorithm.
+ UNIQUE (user_id, device_id, key_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
+`
+
+const upsertKeysSQL = "" +
+ "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
+ " DO UPDATE SET key_json = $6"
+
+const selectOneTimeKeysSQL = "" +
+ "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectKeysCountSQL = "" +
+ "SELECT algorithm, COUNT(key_id) FROM " +
+ " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
+ " x GROUP BY algorithm"
+
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
+const deleteOneTimeKeysSQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
+
+type oneTimeKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
+ deleteOneTimeKeysStmt *sql.Stmt
+}
+
+func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
+ s := &oneTimeKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(oneTimeKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
+
+ wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
+ for _, ka := range keyIDsWithAlgorithms {
+ wantSet[ka] = true
+ }
+
+ result := make(map[string]json.RawMessage)
+ for rows.Next() {
+ var keyID string
+ var algorithm string
+ var keyJSONStr string
+ if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
+ return nil, err
+ }
+ keyIDWithAlgo := algorithm + ":" + keyID
+ if wantSet[keyIDWithAlgo] {
+ result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
+ }
+ }
+ return result, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
+func (s *oneTimeKeysStatements) InsertOneTimeKeys(
+ ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
+) (*api.OneTimeKeysCount, error) {
+ now := time.Now().Unix()
+ counts := &api.OneTimeKeysCount{
+ DeviceID: keys.DeviceID,
+ UserID: keys.UserID,
+ KeyCount: make(map[string]int),
+ }
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+
+ return counts, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ if err != nil {
+ return nil, err
+ }
+ if keyJSON == "" {
+ return nil, nil
+ }
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
+
+func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go
new file mode 100644
index 00000000..f078fc99
--- /dev/null
+++ b/userapi/storage/sqlite3/stale_device_lists.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
+
+type staleDeviceListsStatements struct {
+ db *sql.DB
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+ // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
+}
+
+func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{
+ db: db,
+ }
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
+ }.Prepare(db)
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
+ stmt, err := s.db.Prepare(qry)
+ if err != nil {
+ return err
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
+ stmt = sqlutil.TxStmt(txn, stmt)
+
+ params := make([]any, len(userIDs))
+ for i := range userIDs {
+ params[i] = userIDs[i]
+ }
+
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go
index a1365c94..72b3ba49 100644
--- a/userapi/storage/sqlite3/stats_table.go
+++ b/userapi/storage/sqlite3/stats_table.go
@@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3, 4,
@@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3,
@@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
if err != nil {
return nil, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
registeredAfter := time.Now().AddDate(0, 0, -30)
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 85a1f706..0f3eeed1 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -30,8 +30,8 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
-// NewDatabase creates a new accounts and profiles database
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
+// NewUserDatabase creates a new accounts and profiles database
+func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
@@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewSqliteOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewSqliteDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewSqliteKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewSqliteStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewSqliteCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewSqliteCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go
index 42221e75..0329fb46 100644
--- a/userapi/storage/storage.go
+++ b/userapi/storage/storage.go
@@ -29,15 +29,36 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
-// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
+func NewUserDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ serverName gomatrixserverlib.ServerName,
+ bcryptCost int,
+ openIDTokenLifetimeMS int64,
+ loginTokenLifetime time.Duration,
+ serverNoticesLocalpart string,
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default:
return nil, fmt.Errorf("unexpected database type")
}
}
+
+// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
+// and sets postgres connection parameters.
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
+ switch {
+ case dbProperties.ConnectionString.IsSQLite():
+ return sqlite3.NewKeyDatabase(base, dbProperties)
+ case dbProperties.ConnectionString.IsPostgres():
+ return postgres.NewKeyDatabase(base, dbProperties)
+ default:
+ return nil, fmt.Errorf("unexpected database type")
+ }
+}
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 23aafff0..f52e7e17 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -4,9 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
+ "reflect"
+ "sync"
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
@@ -29,14 +32,14 @@ var (
ctx = context.Background()
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
- t.Fatalf("NewUserAPIDatabase returned %s", err)
+ t.Fatalf("NewUserDatabase returned %s", err)
}
return db, func() {
close()
@@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
@@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
@@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
func Test_LoginToken(t *testing.T) {
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create a new token
@@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
@@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create account, which also creates a profile
@@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
@@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
@@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
@@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(0), total)
})
}
+
+func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
+ if err != nil {
+ t.Fatalf("failed to create new database: %v", err)
+ }
+ return db, close
+}
+
+func MustNotError(t *testing.T, err error) {
+ t.Helper()
+ if err == nil {
+ return
+ }
+ t.Fatalf("operation failed: %s", err)
+}
+
+func TestKeyChanges(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ _, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDC {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesNoDupes(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ if deviceChangeIDA == deviceChangeIDB {
+ t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
+ }
+ deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeID {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesUpperLimit(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDB {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+var dbLock sync.Mutex
+var deviceArray = []string{"AAA", "another_device"}
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+ var err error
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ alice := "@alice:TestDeviceKeysStreamIDGeneration"
+ bob := "@bob:TestDeviceKeysStreamIDGeneration"
+ msgs := []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: bob,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1 as this is a different user
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "another_device",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 2 as this is a 2nd device key
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+ }
+ if msgs[1].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+ }
+ if msgs[2].StreamID != 2 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+ }
+
+ // updating a device sets the next stream ID for that user
+ msgs = []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v2"}`),
+ },
+ // StreamID: 3
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 3 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+ }
+
+ dbLock.Lock()
+ defer dbLock.Unlock()
+ // Querying for device keys returns the latest stream IDs
+ msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
+
+ if err != nil {
+ t.Fatalf("DeviceKeysForUser returned error: %s", err)
+ }
+ wantStreamIDs := map[string]int64{
+ "AAA": 3,
+ "another_device": 2,
+ }
+ if len(msgs) != len(wantStreamIDs) {
+ t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+ }
+ for _, m := range msgs {
+ if m.StreamID != wantStreamIDs[m.DeviceID] {
+ t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+ }
+ }
+ })
+}
diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go
index 5d5d292e..163e3e17 100644
--- a/userapi/storage/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -32,10 +32,10 @@ func NewUserAPIDatabase(
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
-) (Database, error) {
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 9221e571..693e7303 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -20,10 +20,10 @@ import (
"encoding/json"
"time"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
)
@@ -145,3 +145,47 @@ const (
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)
+
+type OneTimeKeys interface {
+ SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+ CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+ InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+ // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
+ // Returns an empty map if the key does not exist.
+ SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
+ DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+}
+
+type DeviceKeys interface {
+ SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+ InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
+ CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+ DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+ DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
+}
+
+type KeyChanges interface {
+ InsertKeyChange(ctx context.Context, userID string) (int64, error)
+ // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
+ // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
+ SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+}
+
+type StaleDeviceLists interface {
+ InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
+ SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+ DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
+}
+
+type CrossSigningKeys interface {
+ SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
+ UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
+}
+
+type CrossSigningSigs interface {
+ SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
+ UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+ DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
+}
diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go
new file mode 100644
index 00000000..b9bdafda
--- /dev/null
+++ b/userapi/storage/tables/stale_device_lists_test.go
@@ -0,0 +1,94 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/userapi/storage/postgres"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/setup/config"
+
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, nil)
+ if err != nil {
+ t.Fatalf("failed to open database: %s", err)
+ }
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
+ }
+ if err != nil {
+ t.Fatalf("failed to create new table: %s", err)
+ }
+ return tab, close
+}
+
+func TestStaleDeviceLists(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := "@charlie:localhost"
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, closeDB := mustCreateTable(t, dbType)
+ defer closeDB()
+
+ if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+
+ // Query one server
+ wantStaleUsers := []string{alice.ID, bob.ID}
+ gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Query all servers
+ wantStaleUsers = []string{alice.ID, bob.ID, charlie}
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Delete stale devices
+ deleteUsers := []string{alice.ID, bob.ID}
+ if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
+ t.Fatalf("failed to delete stale device lists: %s", err)
+ }
+
+ // Verify we don't get anything back after deleting
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+
+ if gotCount := len(gotStaleUsers); gotCount > 0 {
+ t.Fatalf("expected no stale users, got %d", gotCount)
+ }
+ })
+}
diff --git a/userapi/types/storage.go b/userapi/types/storage.go
new file mode 100644
index 00000000..7fb90454
--- /dev/null
+++ b/userapi/types/storage.go
@@ -0,0 +1,50 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package types
+
+import (
+ "math"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const (
+ // OffsetNewest tells e.g. the database to get the most current data
+ OffsetNewest int64 = math.MaxInt64
+ // OffsetOldest tells e.g. the database to get the oldest data
+ OffsetOldest int64 = 0
+)
+
+// KeyTypePurposeToInt maps a purpose to an integer, which is used in the
+// database to reduce the amount of space taken up by this column.
+var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{
+ gomatrixserverlib.CrossSigningKeyPurposeMaster: 1,
+ gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2,
+ gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3,
+}
+
+// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the
+// database to reduce the amount of space taken up by this column.
+var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{
+ 1: gomatrixserverlib.CrossSigningKeyPurposeMaster,
+ 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
+ 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
+}
+
+// Map of purpose -> public key
+type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes
+
+// Map of user ID -> key ID -> signature
+type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes
diff --git a/userapi/userapi.go b/userapi/userapi.go
index 2dd81d75..826bd721 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -17,13 +17,11 @@ package userapi
import (
"time"
+ fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/sirupsen/logrus"
- "github.com/matrix-org/dendrite/internal/pushgateway"
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/consumers"
@@ -33,16 +31,20 @@ import (
"github.com/matrix-org/dendrite/userapi/util"
)
-// NewInternalAPI returns a concerete implementation of the internal API. Callers
+// NewInternalAPI returns a concrete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- base *base.BaseDendrite, cfg *config.UserAPI,
- appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI,
- rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client,
-) api.UserInternalAPI {
+ base *base.BaseDendrite,
+ rsAPI rsapi.UserRoomserverAPI,
+ fedClient fedsenderapi.KeyserverFederationAPI,
+) *internal.UserInternalAPI {
+ cfg := &base.Cfg.UserAPI
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
+ appServices := base.Cfg.Derived.ApplicationServices
- db, err := storage.NewUserAPIDatabase(
+ pgClient := base.PushGatewayHTTPClient()
+
+ db, err := storage.NewUserDatabase(
base,
&cfg.AccountDatabase,
cfg.Matrix.ServerName,
@@ -55,6 +57,11 @@ func NewInternalAPI(
logrus.WithError(err).Panicf("failed to connect to accounts db")
}
+ keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
+ if err != nil {
+ logrus.WithError(err).Panicf("failed to connect to key db")
+ }
+
syncProducer := producers.NewSyncAPI(
db, js,
// TODO: user API should handle syncs for account data. Right now,
@@ -64,17 +71,50 @@ func NewInternalAPI(
cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData),
)
+ keyChangeProducer := &producers.KeyChange{
+ Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
+ JetStream: js,
+ DB: keyDB,
+ }
userAPI := &internal.UserInternalAPI{
DB: db,
+ KeyDatabase: keyDB,
SyncProducer: syncProducer,
+ KeyChangeProducer: keyChangeProducer,
Config: cfg,
AppServices: appServices,
- KeyAPI: keyAPI,
RSAPI: rsAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
PgClient: pgClient,
- Cfg: cfg,
+ FedClient: fedClient,
+ }
+
+ updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
+ userAPI.Updater = updater
+ // Remove users which we don't share a room with anymore
+ if err := updater.CleanUp(); err != nil {
+ logrus.WithError(err).Error("failed to cleanup stale device lists")
+ }
+
+ go func() {
+ if err := updater.Start(); err != nil {
+ logrus.WithError(err).Panicf("failed to start device list updater")
+ }
+ }()
+
+ dlConsumer := consumers.NewDeviceListUpdateConsumer(
+ base.ProcessContext, cfg, js, updater,
+ )
+ if err := dlConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start device list consumer")
+ }
+
+ sigConsumer := consumers.NewSigningKeyUpdateConsumer(
+ base.ProcessContext, cfg, js, userAPI,
+ )
+ if err := sigConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start signing key consumer")
}
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 68d08c2f..08b1336b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -21,7 +21,10 @@ import (
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/nats-io/nats.go"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/setup/config"
@@ -38,32 +41,55 @@ const (
type apiTestOpts struct {
loginTokenLifetime time.Duration
+ serverName string
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
+type dummyProducer struct{}
+
+func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
+ return &nats.PubAck{}, nil
+}
+
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ sName := serverName
+ if opts.serverName != "" {
+ sName = gomatrixserverlib.ServerName(opts.serverName)
+ }
+ accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
- }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
+ }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
+ keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create key DB: %s", err)
+ }
+
cfg := &config.UserAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
- ServerName: serverName,
+ ServerName: sName,
},
},
}
+ syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
+ keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
return &internal.UserInternalAPI{
- DB: accountDB,
- Config: cfg,
+ DB: accountDB,
+ KeyDatabase: keyDB,
+ Config: cfg,
+ SyncProducer: syncProducer,
+ KeyChangeProducer: keyChangeProducer,
}, accountDB, func() {
close()
baseclose()
@@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) {
testCases(t, intAPI)
})
}
+
+func TestAccountData(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser(t)
+
+ testCases := []struct {
+ name string
+ inputData *api.InputAccountDataRequest
+ wantErr bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"},
+ wantErr: true,
+ },
+ {
+ name: "local user missing datatype",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID},
+ wantErr: true,
+ },
+ {
+ name: "missing json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil},
+ wantErr: true,
+ },
+ {
+ name: "with json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")},
+ },
+ {
+ name: "room data",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"},
+ },
+ {
+ name: "ignored users",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")},
+ },
+ {
+ name: "m.fully_read",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")},
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.InputAccountDataResponse{}
+ err := intAPI.InputAccountData(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+
+ // query the data again and compare
+ queryRes := api.QueryAccountDataResponse{}
+ queryReq := api.QueryAccountDataRequest{
+ UserID: tc.inputData.UserID,
+ DataType: tc.inputData.DataType,
+ RoomID: tc.inputData.RoomID,
+ }
+ err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes)
+ if err != nil && !tc.wantErr {
+ t.Fatal(err)
+ }
+ // verify global data
+ if tc.inputData.RoomID == "" {
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType]))
+ }
+ } else {
+ // verify room data
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]))
+ }
+ }
+ })
+ }
+ })
+}
+
+func TestDevices(t *testing.T) {
+ ctx := context.Background()
+
+ dupeAccessToken := util.RandomString(8)
+
+ displayName := "testing"
+
+ creationTests := []struct {
+ name string
+ inputData *api.PerformDeviceCreationRequest
+ wantErr bool
+ wantNewDevID bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
+ wantErr: true,
+ },
+ {
+ name: "implicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
+ },
+ {
+ name: "explicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - not ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ wantErr: true,
+ },
+ {
+ name: "test3 second device", // used to test deletion later
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "test3 third device", // used to test deletion later
+ wantNewDevID: true,
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ }
+
+ deletionTests := []struct {
+ name string
+ inputData *api.PerformDeviceDeletionRequest
+ wantErr bool
+ wantDevices int
+ }{
+ {
+ name: "deletion - not a local user",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"},
+ wantErr: true,
+ },
+ {
+ name: "deleting not existing devices should not error",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}},
+ wantDevices: 1,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"},
+ wantDevices: 0,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"},
+ wantDevices: 0,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range creationTests {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.PerformDeviceCreationResponse{}
+ deviceID := util.RandomString(8)
+ tc.inputData.DeviceID = &deviceID
+ if tc.wantNewDevID {
+ tc.inputData.DeviceID = nil
+ }
+ err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if !res.DeviceCreated {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+ // We only want to verify one device
+ if len(queryDevicesRes.Devices) > 1 {
+ return
+ }
+ res.Device.AccessToken = ""
+
+ // At this point, there should only be one device
+ if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) {
+ t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0])
+ }
+
+ newDisplayName := "new name"
+ if tc.inputData.DeviceDisplayName == nil {
+ updateRes := api.PerformDeviceUpdateResponse{}
+ updateReq := api.PerformDeviceUpdateRequest{
+ RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"),
+ DeviceID: deviceID,
+ DisplayName: &newDisplayName,
+ }
+
+ if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ queryDeviceInfosRes := api.QueryDeviceInfosResponse{}
+ queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}}
+ if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil {
+ t.Fatal(err)
+ }
+ gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName
+ if tc.inputData.DeviceDisplayName != nil {
+ wantDisplayName := *tc.inputData.DeviceDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ } else {
+ wantDisplayName := newDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ }
+ })
+ }
+
+ for _, tc := range deletionTests {
+ t.Run(tc.name, func(t *testing.T) {
+ delRes := api.PerformDeviceDeletionResponse{}
+ err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if tc.wantErr {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+
+ if len(queryDevicesRes.Devices) != tc.wantDevices {
+ t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices))
+ }
+
+ })
+ }
+ })
+}
+
+// Tests that the session ID of a device is not reused when reusing the same device ID.
+func TestDeviceIDReuse(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ res := api.PerformDeviceCreationResponse{}
+ // create a first device
+ deviceID := util.RandomString(8)
+ req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
+ err := intAPI.PerformDeviceCreation(ctx, &req, &res)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Do the same request again, we expect a different sessionID
+ res2 := api.PerformDeviceCreationResponse{}
+ err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
+ if err != nil {
+ t.Fatalf("expected no error, but got: %v", err)
+ }
+
+ if res2.Device.SessionID == res.Device.SessionID {
+ t.Fatalf("expected a different session ID, but they are the same")
+ }
+ })
+}
diff --git a/userapi/util/devices.go b/userapi/util/devices.go
index c55fc799..31617d8c 100644
--- a/userapi/util/devices.go
+++ b/userapi/util/devices.go
@@ -19,7 +19,7 @@ type PusherDevice struct {
}
// GetPushDevices pushes to the configured devices of a local user.
-func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
+func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil {
return nil, fmt.Errorf("db.GetPushers: %w", err)
diff --git a/userapi/util/notify.go b/userapi/util/notify.go
index fc0ab39b..08d1371d 100644
--- a/userapi/util/notify.go
+++ b/userapi/util/notify.go
@@ -13,11 +13,11 @@ import (
)
// NotifyUserCountsAsync sends notifications to a local user's
-// notification destinations. Database lookups run synchronously, but
+// notification destinations. UserDatabase lookups run synchronously, but
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
-func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
+func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error {
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil {
return err
diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go
index f1d20259..421852d3 100644
--- a/userapi/util/notify_test.go
+++ b/userapi/util/notify_test.go
@@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
defer close()
base, _, _ := testrig.Base(nil)
defer base.Close()
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "test", bcrypt.MinCost, 0, 0, "")
if err != nil {
diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go
index 6e62210e..5f626b5b 100644
--- a/userapi/util/phonehomestats_test.go
+++ b/userapi/util/phonehomestats_test.go
@@ -21,7 +21,7 @@ func TestCollect(t *testing.T) {
b, _, _ := testrig.Base(nil)
connStr, closeDB := test.PrepareDBConnectionString(t, dbType)
defer closeDB()
- db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, 1000, 1000, "")
if err != nil {