diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2023-02-20 14:58:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-20 14:58:03 +0100 |
commit | 4594233f89f8531fca8f696ab0ece36909130c2a (patch) | |
tree | 18d3c451041423022e15ba5fcc4a778806ff94dc /userapi | |
parent | bd6f0c14e56af71d83d703b7c91b8cf829ca560f (diff) |
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the
userapi.
Diffstat (limited to 'userapi')
54 files changed, 6550 insertions, 90 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go index 4ea2e91c..fa297f77 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -15,9 +15,13 @@ package api import ( + "bytes" "context" "encoding/json" + "strings" + "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -26,15 +30,12 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { - AppserviceUserAPI SyncUserAPI ClientUserAPI - MediaUserAPI FederationUserAPI - RoomserverUserAPI - KeyserverUserAPI QuerySearchProfilesAPI // used by p2p demos + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the appservice api @@ -43,11 +44,6 @@ type AppserviceUserAPI interface { PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error } -type KeyserverUserAPI interface { - QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error - QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error -} - type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) @@ -60,13 +56,20 @@ type MediaUserAPI interface { // api functions required by the federation api type FederationUserAPI interface { + UploadDeviceKeysAPI QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error + QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error } // api functions required by the sync api type SyncUserAPI interface { QueryAcccessTokenAPI + SyncKeyAPI QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -79,6 +82,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI + ClientKeyAPI QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct { type QueryAccountByLocalpartResponse struct { Account *Account } + +// API functions required by the clientapi +type ClientKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error + + PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error + // PerformClaimKeys claims one-time keys for use in pre-key messages + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type UploadDeviceKeysAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error +} + +// API functions required by the syncapi +type SyncKeyAPI interface { + QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error + QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error + PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error +} + +type FederationKeyAPI interface { + UploadDeviceKeysAPI + QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error + QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error + QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error + PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error +} + +// KeyError is returned if there was a problem performing/querying the server +type KeyError struct { + Err string `json:"error"` + IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE + IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM + IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM +} + +func (k *KeyError) Error() string { + return k.Err +} + +type DeviceMessageType int + +const ( + TypeDeviceKeyUpdate DeviceMessageType = iota + TypeCrossSigningUpdate +) + +// DeviceMessage represents the message produced into Kafka by the key server. +type DeviceMessage struct { + Type DeviceMessageType `json:"Type,omitempty"` + *DeviceKeys `json:"DeviceKeys,omitempty"` + *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` + // A monotonically increasing number which represents device changes for this user. + StreamID int64 + DeviceChangeID int64 +} + +// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log +type OutputCrossSigningKeyUpdate struct { + CrossSigningKeyUpdate `json:"signing_keys"` +} + +type CrossSigningKeyUpdate struct { + MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` + SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` + UserID string `json:"user_id"` +} + +// DeviceKeysEqual returns true if the device keys updates contain the +// same display name and key JSON. This will return false if either of +// the updates is not a device keys update, or if the user ID/device ID +// differ between the two. +func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { + if m1.DeviceKeys == nil || m2.DeviceKeys == nil { + return false + } + if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { + return false + } + if m1.DisplayName != m2.DisplayName { + return false // different display names + } + if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { + return false // either is empty + } + return bytes.Equal(m1.KeyJSON, m2.KeyJSON) +} + +// DeviceKeys represents a set of device keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type DeviceKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // The device display name + DisplayName string + // The raw device key JSON + KeyJSON []byte +} + +// WithStreamID returns a copy of this device message with the given stream ID +func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { + return DeviceMessage{ + DeviceKeys: k, + StreamID: streamID, + } +} + +// OneTimeKeys represents a set of one-time keys for a single device +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload +type OneTimeKeys struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // A map of algorithm:key_id => key JSON + KeyJSON map[string]json.RawMessage +} + +// Split a key in KeyJSON into algorithm and key ID +func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { + segments := strings.Split(keyIDWithAlgo, ":") + return segments[0], segments[1] +} + +// OneTimeKeysCount represents the counts of one-time keys for a single device +type OneTimeKeysCount struct { + // The user who owns this device + UserID string + // The device ID of this device + DeviceID string + // algorithm to count e.g: + // { + // "curve25519": 10, + // "signed_curve25519": 20 + // } + KeyCount map[string]int +} + +// PerformUploadKeysRequest is the request to PerformUploadKeys +type PerformUploadKeysRequest struct { + UserID string // Required - User performing the request + DeviceID string // Optional - Device performing the request, for fetching OTK count + DeviceKeys []DeviceKeys + OneTimeKeys []OneTimeKeys + // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update + // the display name for their respective device, and NOT to modify the keys. The key + // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. + // Without this flag, requests to modify device display names would delete device keys. + OnlyDisplayNameUpdates bool +} + +// PerformUploadKeysResponse is the response to PerformUploadKeys +type PerformUploadKeysResponse struct { + // A fatal error when processing e.g database failures + Error *KeyError + // A map of user_id -> device_id -> Error for tracking failures. + KeyErrors map[string]map[string]*KeyError + OneTimeKeyCounts []OneTimeKeysCount +} + +// PerformDeleteKeysRequest asks the keyserver to forget about certain +// keys, and signatures related to those keys. +type PerformDeleteKeysRequest struct { + UserID string + KeyIDs []gomatrixserverlib.KeyID +} + +// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. +type PerformDeleteKeysResponse struct { + Error *KeyError +} + +// KeyError sets a key error field on KeyErrors +func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { + if r.KeyErrors[userID] == nil { + r.KeyErrors[userID] = make(map[string]*KeyError) + } + r.KeyErrors[userID][deviceID] = err +} + +type PerformClaimKeysRequest struct { + // Map of user_id to device_id to algorithm name + OneTimeKeys map[string]map[string]string + Timeout time.Duration +} + +type PerformClaimKeysResponse struct { + // Map of user_id to device_id to algorithm:key_id to key JSON + OneTimeKeys map[string]map[string]map[string]json.RawMessage + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Set if there was a fatal error processing this action + Error *KeyError +} + +type PerformUploadDeviceKeysRequest struct { + gomatrixserverlib.CrossSigningKeys + // The user that uploaded the key, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceKeysResponse struct { + Error *KeyError +} + +type PerformUploadDeviceSignaturesRequest struct { + Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice + // The user that uploaded the sig, should be populated by the clientapi. + UserID string +} + +type PerformUploadDeviceSignaturesResponse struct { + Error *KeyError +} + +type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string + // Maps user IDs to a list of devices + UserToDevices map[string][]string + Timeout time.Duration +} + +type QueryKeysResponse struct { + // Map of remote server domain to error JSON + Failures map[string]interface{} + // Map of user_id to device_id to device_key + DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // Set if there was a fatal error processing this query + Error *KeyError +} + +type QueryKeyChangesRequest struct { + // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning + Offset int64 + // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. + // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). + ToOffset int64 +} + +type QueryKeyChangesResponse struct { + // The set of users who have had their keys change. + UserIDs []string + // The latest offset represented in this response. + Offset int64 + // Set if there was a problem handling the request. + Error *KeyError +} + +type QueryOneTimeKeysRequest struct { + // The local user to query OTK counts for + UserID string + // The device to query OTK counts for + DeviceID string +} + +type QueryOneTimeKeysResponse struct { + // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 + Count OneTimeKeysCount + Error *KeyError +} + +type QueryDeviceMessagesRequest struct { + UserID string +} + +type QueryDeviceMessagesResponse struct { + // The latest stream ID + StreamID int64 + Devices []DeviceMessage + Error *KeyError +} + +type QuerySignaturesRequest struct { + // A map of target user ID -> target key/device IDs to retrieve signatures for + TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` +} + +type QuerySignaturesResponse struct { + // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures + Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap + // A map of target user ID -> cross-signing master key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing self-signing key + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // A map of target user ID -> cross-signing user-signing key + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey + // The request error, if any + Error *KeyError +} + +type PerformMarkAsStaleRequest struct { + UserID string + Domain gomatrixserverlib.ServerName + DeviceID string +} diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 42ae72e7..51bd2753 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct { jetstream nats.JetStreamContext durable string topic string - db storage.Database + db storage.UserDatabase serverName gomatrixserverlib.ServerName syncProducer *producers.SyncAPI pgClient pushgateway.Client @@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, syncProducer *producers.SyncAPI, pgClient pushgateway.Client, ) *OutputReceiptEventConsumer { diff --git a/userapi/consumers/devicelistupdate.go b/userapi/consumers/devicelistupdate.go new file mode 100644 index 00000000..a65889fc --- /dev/null +++ b/userapi/consumers/devicelistupdate.go @@ -0,0 +1,95 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" +) + +// DeviceListUpdateConsumer consumes device list updates that came in over federation. +type DeviceListUpdateConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool +} + +// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. +func NewDeviceListUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + updater *internal.DeviceListUpdater, +) *DeviceListUpdateConsumer { + return &DeviceListUpdateConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, + } +} + +// Start consuming from key servers +func (t *DeviceListUpdateConsumer) Start() error { + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called in response to a message received on the +// key change events topic from the key server. +func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + var m gomatrixserverlib.DeviceListUpdateEvent + if err := json.Unmarshal(msg.Data, &m); err != nil { + logrus.WithError(err).Errorf("Failed to read from device list update input topic") + return true + } + origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { + return true + } else if t.isLocalServerName(serverName) { + return true + } else if serverName != origin { + return true + } + + err := t.updater.Update(ctx, m) + if err != nil { + logrus.WithFields(logrus.Fields{ + "user_id": m.UserID, + "device_id": m.DeviceID, + "stream_id": m.StreamID, + "prev_id": m.PrevID, + }).WithError(err).Errorf("Failed to update device list") + return false + } + return true +} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3ce5af62..47d33095 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct { rsAPI rsapi.UserRoomserverAPI jetstream nats.JetStreamContext durable string - db storage.Database + db storage.UserDatabase topic string pgClient pushgateway.Client syncProducer *producers.SyncAPI @@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.UserAPI, js nats.JetStreamContext, - store storage.Database, + store storage.UserDatabase, pgClient pushgateway.Client, rsAPI rsapi.UserRoomserverAPI, syncProducer *producers.SyncAPI, diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 39f4aab4..bc5ae652 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -18,11 +18,11 @@ import ( userAPITypes "github.com/matrix-org/dendrite/userapi/types" ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) t.Helper() connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "", 4, 0, 0, "") if err != nil { diff --git a/userapi/consumers/signingkeyupdate.go b/userapi/consumers/signingkeyupdate.go new file mode 100644 index 00000000..f4ff017d --- /dev/null +++ b/userapi/consumers/signingkeyupdate.go @@ -0,0 +1,111 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" +) + +// SigningKeyUpdateConsumer consumes signing key updates that came in over federation. +type SigningKeyUpdateConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + userAPI api.UploadDeviceKeysAPI + cfg *config.UserAPI + isLocalServerName func(gomatrixserverlib.ServerName) bool +} + +// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. +func NewSigningKeyUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + userAPI api.UploadDeviceKeysAPI, +) *SigningKeyUpdateConsumer { + return &SigningKeyUpdateConsumer{ + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + userAPI: userAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, + } +} + +// Start consuming from key servers +func (t *SigningKeyUpdateConsumer) Start() error { + return jetstream.JetStreamConsumer( + t.ctx, t.jetstream, t.topic, t.durable, 1, + t.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +// onMessage is called in response to a message received on the +// signing key update events topic from the key server. +func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + var updatePayload api.CrossSigningKeyUpdate + if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { + logrus.WithError(err).Errorf("Failed to read from signing key update input topic") + return true + } + origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) + if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { + logrus.WithError(err).Error("failed to split user id") + return true + } else if t.isLocalServerName(serverName) { + logrus.Warn("dropping device key update from ourself") + return true + } else if serverName != origin { + logrus.Warnf("dropping device key update, %s != %s", serverName, origin) + return true + } + + keys := gomatrixserverlib.CrossSigningKeys{} + if updatePayload.MasterKey != nil { + keys.MasterKey = *updatePayload.MasterKey + } + if updatePayload.SelfSigningKey != nil { + keys.SelfSigningKey = *updatePayload.SelfSigningKey + } + uploadReq := &api.PerformUploadDeviceKeysRequest{ + CrossSigningKeys: keys, + UserID: updatePayload.UserID, + } + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { + logrus.WithError(err).Error("failed to upload device keys") + return false + } + if uploadRes.Error != nil { + logrus.WithError(uploadRes.Error).Error("failed to upload device keys") + return true + } + + return true +} diff --git a/userapi/internal/cross_signing.go b/userapi/internal/cross_signing.go new file mode 100644 index 00000000..8b9704d1 --- /dev/null +++ b/userapi/internal/cross_signing.go @@ -0,0 +1,587 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "context" + "crypto/ed25519" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/curve25519" +) + +func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error { + // Is there exactly one key? + if len(key.Keys) != 1 { + return fmt.Errorf("should contain exactly one key") + } + + // Does the key ID match the key value? Iterates exactly once + for keyID, keyData := range key.Keys { + b64 := keyData.Encode() + tokens := strings.Split(string(keyID), ":") + if len(tokens) != 2 { + return fmt.Errorf("key ID is incorrectly formatted") + } + if tokens[1] != b64 { + return fmt.Errorf("key ID isn't correct") + } + switch tokens[0] { + case "ed25519": + if len(keyData) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key is not the correct length") + } + case "curve25519": + if len(keyData) != curve25519.PointSize { + return fmt.Errorf("curve25519 key is not the correct length") + } + default: + // We can't enforce the key length to be correct for an + // algorithm that we don't recognise, so instead we'll + // just make sure that it isn't incredibly excessive. + if l := len(keyData); l > 4096 { + return fmt.Errorf("unknown key type is too long (%d bytes)", l) + } + } + } + + // Check to see if the signatures make sense + for _, forOriginUser := range key.Signatures { + for originKeyID, originSignature := range forOriginUser { + switch strings.SplitN(string(originKeyID), ":", 1)[0] { + case "ed25519": + if len(originSignature) != ed25519.SignatureSize { + return fmt.Errorf("ed25519 signature is not the correct length") + } + case "curve25519": + return fmt.Errorf("curve25519 signatures are impossible") + default: + if l := len(originSignature); l > 4096 { + return fmt.Errorf("unknown signature type is too long (%d bytes)", l) + } + } + } + } + + // Does the key claim to be from the right user? + if userID != key.UserID { + return fmt.Errorf("key has a user ID mismatch") + } + + // Does the key contain the correct purpose? + useful := false + for _, usage := range key.Usage { + if usage == purpose { + useful = true + break + } + } + if !useful { + return fmt.Errorf("key does not contain correct usage purpose") + } + + return nil +} + +// nolint:gocyclo +func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + // Find the keys to store. + byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + toStore := types.CrossSigningKeyMap{} + hasMasterKey := false + + if len(req.MasterKey.Keys) > 0 { + if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil { + res.Error = &api.KeyError{ + Err: "Master key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey + for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key + } + hasMasterKey = true + } + + if len(req.SelfSigningKey.Keys) > 0 { + if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil { + res.Error = &api.KeyError{ + Err: "Self-signing key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey + for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key + } + } + + if len(req.UserSigningKey.Keys) > 0 { + if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil { + res.Error = &api.KeyError{ + Err: "User-signing key sanity check failed: " + err.Error(), + IsInvalidParam: true, + } + return nil + } + + byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey + for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey + toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key + } + } + + // If there's nothing to do then stop here. + if len(toStore) == 0 { + res.Error = &api.KeyError{ + Err: "No keys were supplied in the request", + IsMissingParam: true, + } + return nil + } + + // We can't have a self-signing or user-signing key without a master + // key, so make sure we have one of those. We will also only actually do + // something if any of the specified keys in the request are different + // to what we've got in the database, to avoid generating key change + // notifications unnecessarily. + existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID) + if err != nil { + res.Error = &api.KeyError{ + Err: "Retrieving cross-signing keys from database failed: " + err.Error(), + } + return nil + } + + // If we still can't find a master key for the user then stop the upload. + // This satisfies the "Fails to upload self-signing key without master key" test. + if !hasMasterKey { + if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { + res.Error = &api.KeyError{ + Err: "No master key was found", + IsMissingParam: true, + } + return nil + } + } + + // Check if anything actually changed compared to what we have in the database. + changed := false + for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ + gomatrixserverlib.CrossSigningKeyPurposeMaster, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning, + } { + old, gotOld := existingKeys[purpose] + new, gotNew := toStore[purpose] + if gotOld != gotNew { + // A new key purpose has been specified that we didn't know before, + // or one has been removed. + changed = true + break + } + if !bytes.Equal(old, new) { + // One of the existing keys for a purpose we already knew about has + // changed. + changed = true + break + } + } + if !changed { + return nil + } + + // Store the keys. + if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), + } + return nil + } + + // Now upload any signatures that were included with the keys. + for _, key := range byPurpose { + var targetKeyID gomatrixserverlib.KeyID + for targetKey := range key.Keys { // iterates once, see sanityCheckKey + targetKeyID = targetKey + } + for sigUserID, forSigUserID := range key.Signatures { + if sigUserID != req.UserID { + continue + } + for sigKeyID, sigBytes := range forSigUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), + } + return nil + } + } + } + } + + // Finally, generate a notification that we updated the keys. + update := api.CrossSigningKeyUpdate{ + UserID: req.UserID, + } + if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { + update.MasterKey = &mk + } + if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { + update.SelfSigningKey = &ssk + } + if update.MasterKey == nil && update.SelfSigningKey == nil { + return nil + } + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), + } + return nil + } + return nil +} + +func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { + // Before we do anything, we need the master and self-signing keys for this user. + // Then we can verify the signatures make sense. + queryReq := &api.QueryKeysRequest{ + UserID: req.UserID, + UserToDevices: map[string][]string{}, + } + queryRes := &api.QueryKeysResponse{} + for userID := range req.Signatures { + queryReq.UserToDevices[userID] = []string{} + } + _ = a.QueryKeys(ctx, queryReq, queryRes) + + selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + + // Sort signatures into two groups: one where people have signed their own + // keys and one where people have signed someone elses + for userID, forUserID := range req.Signatures { + for keyID, keyOrDevice := range forUserID { + switch key := keyOrDevice.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + if key.UserID == req.UserID { + if _, ok := selfSignatures[userID]; !ok { + selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + selfSignatures[userID][keyID] = keyOrDevice + } else { + if _, ok := otherSignatures[userID]; !ok { + otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + otherSignatures[userID][keyID] = keyOrDevice + } + + case *gomatrixserverlib.DeviceKeys: + if key.UserID == req.UserID { + if _, ok := selfSignatures[userID]; !ok { + selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + selfSignatures[userID][keyID] = keyOrDevice + } else { + if _, ok := otherSignatures[userID]; !ok { + otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} + } + otherSignatures[userID][keyID] = keyOrDevice + } + + default: + continue + } + } + } + + if err := a.processSelfSignatures(ctx, selfSignatures); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.processSelfSignatures: %s", err), + } + return nil + } + + if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.processOtherSignatures: %s", err), + } + return nil + } + + // Finally, generate a notification that we updated the signatures. + for userID := range req.Signatures { + masterKey := queryRes.MasterKeys[userID] + selfSigningKey := queryRes.SelfSigningKeys[userID] + update := api.CrossSigningKeyUpdate{ + UserID: userID, + MasterKey: &masterKey, + SelfSigningKey: &selfSigningKey, + } + if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), + } + return nil + } + } + return nil +} + +func (a *UserInternalAPI) processSelfSignatures( + ctx context.Context, + signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, +) error { + // Here we will process: + // * The user signing their own devices using their self-signing key + // * The user signing their master key using one of their devices + + for targetUserID, forTargetUserID := range signatures { + for targetKeyID, signature := range forTargetUserID { + switch sig := signature.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + for keyID := range sig.Keys { + split := strings.SplitN(string(keyID), ":", 2) + if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { + targetKeyID = keyID // contains the ed25519: or other scheme + break + } + } + for originUserID, forOriginUserID := range sig.Signatures { + for originKeyID, originSig := range forOriginUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + case *gomatrixserverlib.DeviceKeys: + for originUserID, forOriginUserID := range sig.Signatures { + for originKeyID, originSig := range forOriginUserID { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + default: + return fmt.Errorf("unexpected type assertion") + } + } + } + + return nil +} + +func (a *UserInternalAPI) processOtherSignatures( + ctx context.Context, userID string, queryRes *api.QueryKeysResponse, + signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, +) error { + // Here we will process: + // * A user signing someone else's master keys using their user-signing keys + + for targetUserID, forTargetUserID := range signatures { + for _, signature := range forTargetUserID { + switch sig := signature.CrossSigningBody.(type) { + case *gomatrixserverlib.CrossSigningKey: + // Find the local copy of the master key. We'll use this to be + // sure that the supplied stanza matches the key that we think it + // should be. + masterKey, ok := queryRes.MasterKeys[targetUserID] + if !ok { + return fmt.Errorf("failed to find master key for user %q", targetUserID) + } + + // For each key ID, write the signatures. Maybe there'll be more + // than one algorithm in the future so it's best not to focus on + // everything being ed25519:. + for targetKeyID, suppliedKeyData := range sig.Keys { + // The master key will be supplied in the request, but we should + // make sure that it matches what we think the master key should + // actually be. + localKeyData, lok := masterKey.Keys[targetKeyID] + if !lok { + return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) + } else if !bytes.Equal(suppliedKeyData, localKeyData) { + return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) + } + + // We only care about the signatures from the uploading user, so + // we will ignore anything that didn't originate from them. + userSigs, ok := sig.Signatures[userID] + if !ok { + return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID) + } + + for originKeyID, originSig := range userSigs { + if err := a.KeyDatabase.StoreCrossSigningSigsForTarget( + ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, + ); err != nil { + return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) + } + } + } + + default: + // Users should only be signing another person's master key, + // so if we're here, it's probably because it's actually a + // gomatrixserverlib.DeviceKeys, which doesn't make sense. + } + } + } + + return nil +} + +func (a *UserInternalAPI) crossSigningKeysFromDatabase( + ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, +) { + for targetUserID := range req.UserToDevices { + keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) + if err != nil { + logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) + continue + } + + for keyType, key := range keys { + var keyID gomatrixserverlib.KeyID + for id := range key.Keys { + keyID = id + break + } + + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) + if err != nil && err != sql.ErrNoRows { + logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) + continue + } + + appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { + if key.Signatures == nil { + key.Signatures = types.CrossSigningSigMap{} + } + if _, ok := key.Signatures[originUserID]; !ok { + key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) + } + key.Signatures[originUserID][originKeyID] = signature + } + + for originUserID, forOrigin := range sigMap { + for originKeyID, signature := range forOrigin { + switch { + case req.UserID != "" && originUserID == req.UserID: + // Include signatures that we created + appendSignature(originUserID, originKeyID, signature) + case originUserID == targetUserID: + // Include signatures that were created by the person whose key + // we are processing + appendSignature(originUserID, originKeyID, signature) + } + } + } + + switch keyType { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + res.MasterKeys[targetUserID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + res.SelfSigningKeys[targetUserID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + res.UserSigningKeys[targetUserID] = key + } + } + } +} + +func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { + for targetUserID, forTargetUser := range req.TargetIDs { + keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID) + if err != nil && err != sql.ErrNoRows { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), + } + continue + } + + for targetPurpose, targetKey := range keyMap { + switch targetPurpose { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + if res.MasterKeys == nil { + res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.MasterKeys[targetUserID] = targetKey + + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + if res.SelfSigningKeys == nil { + res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.SelfSigningKeys[targetUserID] = targetKey + + case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + if res.UserSigningKeys == nil { + res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} + } + res.UserSigningKeys[targetUserID] = targetKey + } + } + + for _, targetKeyID := range forTargetUser { + // Get own signatures only. + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) + if err != nil && err != sql.ErrNoRows { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), + } + return nil + } + + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if res.Signatures == nil { + res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID]; !ok { + res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok { + res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} + } + if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { + res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig + } + } + } + } + return nil +} diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go new file mode 100644 index 00000000..3b4dcf98 --- /dev/null +++ b/userapi/internal/device_list_update.go @@ -0,0 +1,579 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "hash/fnv" + "net" + "sync" + "time" + + rsapi "github.com/matrix-org/dendrite/roomserver/api" + + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/api" +) + +var ( + deviceListUpdateCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "keyserver", + Name: "device_list_update", + Help: "Number of times we have attempted to update device lists from this server", + }, + []string{"server"}, + ) +) + +const requestTimeout = time.Second * 30 + +func init() { + prometheus.MustRegister( + deviceListUpdateCount, + ) +} + +// DeviceListUpdater handles device list updates from remote servers. +// +// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock). +// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies +// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id +// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device: +// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the +// updater stores the latest list along with the latest stream ID. +// +// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers. +// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing +// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved +// from the database (which allows us to batch requests to the same server). This has a number of desirable properties: +// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible +// for that domain. +// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where +// we have many many servers) +// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers. +// +// The downsides are that: +// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free +// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts) +// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests +// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse +// than being stuck behind foo.bar +// +// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is +// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. +type DeviceListUpdater struct { + process *process.ProcessContext + // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 + // request to the remote server and race. + // TODO: Put in an LRU cache to bound growth + userIDToMutex map[string]*sync.Mutex + mu *sync.Mutex // protects UserIDToMutex + + db DeviceListUpdaterDatabase + api DeviceListUpdaterAPI + producer KeyChangeProducer + fedClient fedsenderapi.KeyserverFederationAPI + workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName + + // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will + // block on or timeout via a select. + userIDToChan map[string]chan bool + userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI +} + +// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. +// Useful for testing. +type DeviceListUpdaterDatabase interface { + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error +} + +type DeviceListUpdaterAPI interface { + PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error +} + +// KeyChangeProducer is the interface for producers.KeyChange useful for testing. +type KeyChangeProducer interface { + ProduceKeyChanges(keys []api.DeviceMessage) error +} + +// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. +func NewDeviceListUpdater( + process *process.ProcessContext, db DeviceListUpdaterDatabase, + api DeviceListUpdaterAPI, producer KeyChangeProducer, + fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, +) *DeviceListUpdater { + return &DeviceListUpdater{ + process: process, + userIDToMutex: make(map[string]*sync.Mutex), + mu: &sync.Mutex{}, + db: db, + api: api, + producer: producer, + fedClient: fedClient, + thisServer: thisServer, + workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + userIDToChan: make(map[string]chan bool), + userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, + } +} + +// Start the device list updater, which will try to refresh any stale device lists. +func (u *DeviceListUpdater) Start() error { + for i := 0; i < len(u.workerChans); i++ { + // Allocate a small buffer per channel. + // If the buffer limit is reached, backpressure will cause the processing of EDUs + // to stop (in this transaction) until key requests can be made. + ch := make(chan gomatrixserverlib.ServerName, 10) + u.workerChans[i] = ch + go u.worker(ch) + } + + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + offset, step := time.Second*10, time.Second + if max := len(staleLists); max > 120 { + step = (time.Second * 120) / time.Duration(max) + } + for _, userID := range staleLists { + userID := userID // otherwise we are only sending the last entry + time.AfterFunc(offset, func() { + u.notifyWorkers(userID) + }) + offset += step + } + return nil +} + +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + +func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { + u.mu.Lock() + defer u.mu.Unlock() + if u.userIDToMutex[userID] == nil { + u.userIDToMutex[userID] = &sync.Mutex{} + } + return u.userIDToMutex[userID] +} + +// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. +// Blocks until the device list is synced or the timeout is reached. +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { + mu := u.mutex(userID) + mu.Lock() + err := u.db.MarkDeviceListStale(ctx, userID, true) + mu.Unlock() + if err != nil { + return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err) + } + u.notifyWorkers(userID) + return nil +} + +// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest, +// which assumes when /send 200 OKs that the device lists have been updated. +func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { + isDeviceListStale, err := u.update(ctx, event) + if err != nil { + return err + } + if isDeviceListStale { + // poke workers to handle stale device lists + u.notifyWorkers(event.UserID) + } + return nil +} + +func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) { + mu := u.mutex(event.UserID) + mu.Lock() + defer mu.Unlock() + // check if we have the prev IDs + exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID) + if err != nil { + return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) + } + // if this is the first time we're hearing about this user, sync the device list manually. + if len(event.PrevID) == 0 { + exists = false + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "prev_ids_exist": exists, + "user_id": event.UserID, + "device_id": event.DeviceID, + "stream_id": event.StreamID, + "prev_ids": event.PrevID, + "display_name": event.DeviceDisplayName, + "deleted": event.Deleted, + }).Trace("DeviceListUpdater.Update") + + // if we haven't missed anything update the database and notify users + if exists || event.Deleted { + k := event.Keys + if event.Deleted { + k = nil + } + keys := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: k, + UserID: event.UserID, + }, + StreamID: event.StreamID, + }, + } + + // DeviceKeysJSON will side-effect modify this, so it needs + // to be a copy, not sharing any pointers with the above. + deviceKeysCopy := *keys[0].DeviceKeys + deviceKeysCopy.KeyJSON = nil + existingKeys := []api.DeviceMessage{ + { + Type: keys[0].Type, + DeviceKeys: &deviceKeysCopy, + StreamID: keys[0].StreamID, + }, + } + + // fetch what keys we had already and only emit changes + if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + // non-fatal, log and continue + util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf( + "failed to query device keys json for calculating diffs", + ) + } + + err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) + if err != nil { + return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) + } + + if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil { + return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) + } + return false, nil + } + + err = u.db.MarkDeviceListStale(ctx, event.UserID, true) + if err != nil { + return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err) + } + + return true, nil +} + +func (u *DeviceListUpdater) notifyWorkers(userID string) { + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return + } + hash := fnv.New32a() + _, _ = hash.Write([]byte(remoteServer)) + index := int(int64(hash.Sum32()) % int64(len(u.workerChans))) + + ch := u.assignChannel(userID) + u.workerChans[index] <- remoteServer + select { + case <-ch: + case <-time.After(10 * time.Second): + // we don't return an error in this case as it's not a failure condition. + // we mainly block for the benefit of sytest anyway + } +} + +func (u *DeviceListUpdater) assignChannel(userID string) chan bool { + u.userIDToChanMu.Lock() + defer u.userIDToChanMu.Unlock() + if ch, ok := u.userIDToChan[userID]; ok { + return ch + } + ch := make(chan bool) + u.userIDToChan[userID] = ch + return ch +} + +func (u *DeviceListUpdater) clearChannel(userID string) { + u.userIDToChanMu.Lock() + defer u.userIDToChanMu.Unlock() + if ch, ok := u.userIDToChan[userID]; ok { + close(ch) + delete(u.userIDToChan, userID) + } +} + +func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { + retries := make(map[gomatrixserverlib.ServerName]time.Time) + retriesMu := &sync.Mutex{} + // restarter goroutine which will inject failed servers into ch when it is time + go func() { + var serversToRetry []gomatrixserverlib.ServerName + for { + serversToRetry = serversToRetry[:0] // reuse memory + time.Sleep(time.Second) + retriesMu.Lock() + now := time.Now() + for srv, retryAt := range retries { + if now.After(retryAt) { + serversToRetry = append(serversToRetry, srv) + } + } + for _, srv := range serversToRetry { + delete(retries, srv) + } + retriesMu.Unlock() + for _, srv := range serversToRetry { + ch <- srv + } + } + }() + for serverName := range ch { + retriesMu.Lock() + _, exists := retries[serverName] + retriesMu.Unlock() + if exists { + // Don't retry a server that we're already waiting for. + continue + } + waitTime, shouldRetry := u.processServer(serverName) + if shouldRetry { + retriesMu.Lock() + if _, exists = retries[serverName]; !exists { + retries[serverName] = time.Now().Add(waitTime) + } + retriesMu.Unlock() + } + } +} + +func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { + ctx := u.process.Context() + logger := util.GetLogger(ctx).WithField("server_name", serverName) + deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() + + waitTime := defaultWaitTime // How long should we wait to try again? + successCount := 0 // How many user requests failed? + + userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + if err != nil { + logger.WithError(err).Error("Failed to load stale device lists") + return waitTime, true + } + + defer func() { + for _, userID := range userIDs { + // always clear the channel to unblock Update calls regardless of success/failure + u.clearChannel(userID) + } + }() + + for _, userID := range userIDs { + userWait, err := u.processServerUser(ctx, serverName, userID) + if err != nil { + if userWait > waitTime { + waitTime = userWait + } + break + } + successCount++ + } + + allUsersSucceeded := successCount == len(userIDs) + if !allUsersSucceeded { + logger.WithFields(logrus.Fields{ + "total": len(userIDs), + "succeeded": successCount, + "failed": len(userIDs) - successCount, + "wait_time": waitTime, + }).Debug("Failed to query device keys for some users") + } + return waitTime, !allUsersSucceeded +} + +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { + ctx, cancel := context.WithTimeout(ctx, requestTimeout) + defer cancel() + logger := util.GetLogger(ctx).WithFields(logrus.Fields{ + "server_name": serverName, + "user_id": userID, + }) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return time.Minute * 10, err + } + switch e := err.(type) { + case *json.UnmarshalTypeError, *json.SyntaxError: + logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID) + return defaultWaitTime, nil + case *fedsenderapi.FederationClientError: + if e.RetryAfter > 0 { + return e.RetryAfter, err + } else if e.Blacklisted { + return time.Hour * 8, err + } + case net.Error: + // Use the default waitTime, if it's a timeout. + // It probably doesn't make sense to try further users. + if !e.Timeout() { + logger.WithError(e).Debug("GetUserDevices returned net.Error") + return time.Minute * 10, err + } + case gomatrix.HTTPError: + // The remote server returned an error, give it some time to recover. + // This is to avoid spamming remote servers, which may not be Matrix servers anymore. + if e.Code >= 300 { + logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") + return hourWaitTime, err + } + default: + // Something else failed + logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err) + return time.Minute * 10, err + } + } + if res.UserID != userID { + logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID) + return defaultWaitTime, nil + } + if res.MasterKey != nil || res.SelfSigningKey != nil { + uploadReq := &api.PerformUploadDeviceKeysRequest{ + UserID: userID, + } + uploadRes := &api.PerformUploadDeviceKeysResponse{} + if res.MasterKey != nil { + if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { + uploadReq.MasterKey = *res.MasterKey + } + } + if res.SelfSigningKey != nil { + if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { + uploadReq.SelfSigningKey = *res.SelfSigningKey + } + } + _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) + } + err = u.updateDeviceList(&res) + if err != nil { + logger.WithError(err).Error("Fetched device list but failed to store/emit it") + return defaultWaitTime, err + } + return defaultWaitTime, nil +} + +func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { + ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. + keys := make([]api.DeviceMessage, len(res.Devices)) + existingKeys := make([]api.DeviceMessage, len(res.Devices)) + for i, device := range res.Devices { + keyJSON, err := json.Marshal(device.Keys) + if err != nil { + util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") + continue + } + keys[i] = api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: res.StreamID, + DeviceKeys: &api.DeviceKeys{ + DeviceID: device.DeviceID, + DisplayName: device.DisplayName, + UserID: res.UserID, + KeyJSON: keyJSON, + }, + } + existingKeys[i] = api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: res.UserID, + DeviceID: device.DeviceID, + }, + } + } + // fetch what keys we had already and only emit changes + if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { + // non-fatal, log and continue + util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf( + "failed to query device keys json for calculating diffs", + ) + } + + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) + if err != nil { + return fmt.Errorf("failed to store remote device keys: %w", err) + } + err = u.db.MarkDeviceListStale(ctx, res.UserID, false) + if err != nil { + return fmt.Errorf("failed to mark device list as fresh: %w", err) + } + err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false) + if err != nil { + return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) + } + return nil +} diff --git a/userapi/internal/device_list_update_default.go b/userapi/internal/device_list_update_default.go new file mode 100644 index 00000000..7d357c95 --- /dev/null +++ b/userapi/internal/device_list_update_default.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vw + +package internal + +import "time" + +const defaultWaitTime = time.Minute +const hourWaitTime = time.Hour diff --git a/userapi/internal/device_list_update_sytest.go b/userapi/internal/device_list_update_sytest.go new file mode 100644 index 00000000..1c60d2eb --- /dev/null +++ b/userapi/internal/device_list_update_sytest.go @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vw + +package internal + +import "time" + +// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite +// results in a one-hour wait time from a previous device so the test times out. This is fine for +// production, but makes an otherwise passing test fail. +const defaultWaitTime = time.Second +const hourWaitTime = time.Second diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go new file mode 100644 index 00000000..868fc9be --- /dev/null +++ b/userapi/internal/device_list_update_test.go @@ -0,0 +1,431 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "crypto/ed25519" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" +) + +var ( + ctx = context.Background() +) + +type mockKeyChangeProducer struct { + events []api.DeviceMessage +} + +func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { + p.events = append(p.events, keys...) + return nil +} + +type mockDeviceListUpdaterDatabase struct { + staleUsers map[string]bool + prevIDsExist func(string, []int64) bool + storedKeys []api.DeviceMessage + mu sync.Mutex // protect staleUsers +} + +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + d.mu.Lock() + defer d.mu.Unlock() + var result []string + for userID, isStale := range d.staleUsers { + if !isStale { + continue + } + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + if len(domains) == 0 { + result = append(result, userID) + continue + } + for _, d := range domains { + if remoteServer == d { + result = append(result, userID) + break + } + } + } + return result, nil +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + d.mu.Lock() + defer d.mu.Unlock() + d.staleUsers[userID] = isStale + return nil +} + +func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + return d.staleUsers[userID] +} + +// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key +// for this (user, device). Does not modify the stream ID for keys. +func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { + d.storedKeys = append(d.storedKeys, keys...) + return nil +} + +// PrevIDsExists returns true if all prev IDs exist for this user. +func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + return d.prevIDsExist(userID, prevIDs), nil +} + +func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return nil +} + +type mockDeviceListUpdaterAPI struct { +} + +func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { + return nil +} + +type roundTripper struct { + fn func(*http.Request) (*http.Response, error) +} + +func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.fn(req) +} + +func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { + _, pkey, _ := ed25519.GenerateKey(nil) + fedClient := gomatrixserverlib.NewFederationClient( + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, + ) + fedClient.Client = *gomatrixserverlib.NewClient( + gomatrixserverlib.WithTransport(&roundTripper{tripper}), + ) + return fedClient +} + +// Test that the device keys get persisted and emitted if we have the previous IDs. +func TestUpdateHavePrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Foo Bar", + Deleted: false, + DeviceID: "FOO", + Keys: []byte(`{"key":"value"}`), + PrevID: []int64{0}, + StreamID: 1, + UserID: "@alice:localhost", + } + err := updater.Update(ctx, event) + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: event.StreamID, + DeviceKeys: &api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: event.Keys, + UserID: event.UserID, + }, + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + if db.isStale(event.UserID) { + t.Errorf("%s incorrectly marked as stale", event.UserID) + } +} + +// Test that device keys are fetched from the remote server if we are missing prev IDs +// and that the user's devices are marked as stale until it succeeds. +func TestUpdateNoPrevID(t *testing.T) { + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return false + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + remoteUserID := "@alice:example.somewhere" + var wg sync.WaitGroup + wg.Add(1) + keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + defer wg.Done() + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(` + { + "user_id": "` + remoteUserID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + }, nil + }) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + event := gomatrixserverlib.DeviceListUpdateEvent{ + DeviceDisplayName: "Mobile Phone", + Deleted: false, + DeviceID: "another_device_id", + Keys: []byte(`{"key":"value"}`), + PrevID: []int64{3}, + StreamID: 4, + UserID: remoteUserID, + } + err := updater.Update(ctx, event) + + if err != nil { + t.Fatalf("Update returned an error: %s", err) + } + t.Log("waiting for /users/devices to be called...") + wg.Wait() + // wait a bit for db to be updated... + time.Sleep(100 * time.Millisecond) + want := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + StreamID: 5, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "JLAFKJWSCS", + DisplayName: "Mobile Phone", + UserID: remoteUserID, + KeyJSON: []byte(keyJSON), + }, + } + // Now we should have a fresh list and the keys and emitted something + if db.isStale(event.UserID) { + t.Errorf("%s still marked as stale", event.UserID) + } + if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { + t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) + t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) + } + if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { + t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + } + +} + +// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the +// update is still ongoing. +func TestDebounce(t *testing.T) { + t.Skipf("panic on closed channel on GHA") + db := &mockDeviceListUpdaterDatabase{ + staleUsers: make(map[string]bool), + prevIDsExist: func(string, []int64) bool { + return true + }, + } + ap := &mockDeviceListUpdaterAPI{} + producer := &mockKeyChangeProducer{} + fedCh := make(chan *http.Response, 1) + srv := gomatrixserverlib.ServerName("example.com") + userID := "@alice:example.com" + keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` + incomingFedReq := make(chan struct{}) + fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { + return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) + } + close(incomingFedReq) + return <-fedCh, nil + }) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") + if err := updater.Start(); err != nil { + t.Fatalf("failed to start updater: %s", err) + } + + // hit this 5 times + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { + t.Errorf("ManualUpdate: %s", err) + } + }() + } + + // wait until the updater hits federation + select { + case <-incomingFedReq: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for updater to hit federation") + } + + // user should be marked as stale + if !db.isStale(userID) { + t.Errorf("user %s not marked as stale", userID) + } + // now send the response over federation + fedCh <- &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(` + { + "user_id": "` + userID + `", + "stream_id": 5, + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": ` + keyJSON + `, + "device_display_name": "Mobile Phone" + } + ] + } + `)), + } + close(fedCh) + // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response + // and should panic with read on a closed channel + wg.Wait() + + // user is no longer stale now + if db.isStale(userID) { + t.Errorf("user %s is marked as stale", userID) + } +} + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go new file mode 100644 index 00000000..be816fe5 --- /dev/null +++ b/userapi/internal/key_api.go @@ -0,0 +1,798 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/matrix-org/dendrite/userapi/api" +) + +func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { + userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset) + if err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return nil + } + res.Offset = latest + res.UserIDs = userIDs + return nil +} + +func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { + res.KeyErrors = make(map[string]map[string]*api.KeyError) + if len(req.DeviceKeys) > 0 { + a.uploadLocalDeviceKeys(ctx, req, res) + } + if len(req.OneTimeKeys) > 0 { + a.uploadOneTimeKeys(ctx, req, res) + } + otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + return err + } + res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} + return nil +} + +func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { + res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) + res.Failures = make(map[string]interface{}) + // wrap request map in a top-level by-domain map + domainToDeviceKeys := make(map[string]map[string]map[string]string) + for userID, val := range req.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + nested, ok := domainToDeviceKeys[string(serverName)] + if !ok { + nested = make(map[string]map[string]string) + } + nested[userID] = val + domainToDeviceKeys[string(serverName)] = nested + } + for domain, local := range domainToDeviceKeys { + if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys + keys, err := a.KeyDatabase.ClaimKeys(ctx, local) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), + } + } + util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") + for _, key := range keys { + _, ok := res.OneTimeKeys[key.UserID] + if !ok { + res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] + if !ok { + res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON + } + } + delete(domainToDeviceKeys, domain) + } + if len(domainToDeviceKeys) > 0 { + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) + } + return nil +} + +func (a *UserInternalAPI) claimRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, +) { + var wg sync.WaitGroup // Wait for fan-out goroutines to finish + var mu sync.Mutex // Protects the response struct + var claimed int // Number of keys claimed in total + var failures int // Number of servers we failed to ask + + util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys)) + wg.Add(len(domainToDeviceKeys)) + + for d, k := range domainToDeviceKeys { + go func(domain string, keysToClaim map[string]map[string]string) { + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + defer wg.Done() + + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) + + mu.Lock() + defer mu.Unlock() + + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") + res.Failures[domain] = map[string]interface{}{ + "message": err.Error(), + } + failures++ + return + } + + for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, keys := range deviceIDToKeys { + res.OneTimeKeys[userID][deviceID] = keys + claimed += len(keys) + } + } + }(d, k) + } + + wg.Wait() + util.GetLogger(ctx).WithFields(logrus.Fields{ + "num_keys": claimed, + "num_failures": failures, + }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) +} + +func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { + if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to delete device keys: %s", err), + } + } + return nil +} + +func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { + count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("Failed to query OTK counts: %s", err), + } + return nil + } + res.Count = *count + return nil +} + +func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { + msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query DB for device keys: %s", err), + } + return nil + } + maxStreamID := int64(0) + // remove deleted devices + var result []api.DeviceMessage + for _, m := range msgs { + if m.StreamID > maxStreamID { + maxStreamID = m.StreamID + } + if m.KeyJSON == nil || len(m.KeyJSON) == 0 { + continue + } + result = append(result, m) + } + res.Devices = result + res.StreamID = maxStreamID + return nil +} + +// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present +// in our database. +func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { + knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true) + if err != nil { + return err + } + if len(knownDevices) == 0 { + return nil // fmt.Errorf("unknown user %s", req.UserID) + } + + for i := range knownDevices { + if knownDevices[i].DeviceID == req.DeviceID { + return nil // we already know about this device + } + } + + return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) +} + +// nolint:gocyclo +func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { + var respMu sync.Mutex + res.DeviceKeys = make(map[string]map[string]json.RawMessage) + res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) + res.Failures = make(map[string]interface{}) + + // make a map from domain to device keys + domainToDeviceKeys := make(map[string]map[string][]string) + domainToCrossSigningKeys := make(map[string]map[string]struct{}) + for userID, deviceIDs := range req.UserToDevices { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + domain := string(serverName) + // query local devices + if a.Config.Matrix.IsLocalServerName(serverName) { + deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query local device keys: %s", err), + } + return nil + } + + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) + } + var queryRes api.QueryDeviceInfosResponse + err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + } + + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{displayName}) + res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + } + } else { + domainToDeviceKeys[domain] = make(map[string][]string) + domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) + } + // work out if our cross-signing request for this user was + // satisfied, if not add them to the list of things to fetch + if _, ok := res.MasterKeys[userID]; !ok { + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) + } + domainToCrossSigningKeys[domain][userID] = struct{}{} + } + if _, ok := res.SelfSigningKeys[userID]; !ok { + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) + } + domainToCrossSigningKeys[domain][userID] = struct{}{} + } + } + + // attempt to satisfy key queries from the local database first as we should get device updates pushed to us + domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) + if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { + // perform key queries for remote devices + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) + } + + // Now that we've done the potentially expensive work of asking the federation, + // try filling the cross-signing keys from the database that we know about. + a.crossSigningKeysFromDatabase(ctx, req, res) + + // Finally, append signatures that we know about + // TODO: This is horrible because we need to round-trip the signature from + // JSON, add the signatures and marshal it again, for some reason? + + for targetUserID, masterKey := range res.MasterKeys { + if masterKey.Signatures == nil { + masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for targetKeyID := range masterKey.Keys { + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) + if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") + continue + } + if len(sigMap) == 0 { + continue + } + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := masterKey.Signatures[sourceUserID]; !ok { + masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig + } + } + } + } + + for targetUserID, forUserID := range res.DeviceKeys { + for targetKeyID, key := range forUserID { + sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) + if err != nil { + // Stop executing the function if the context was canceled/the deadline was exceeded, + // as we can't continue without a valid context. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil + } + logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed") + continue + } + if len(sigMap) == 0 { + continue + } + var deviceKey gomatrixserverlib.DeviceKeys + if err = json.Unmarshal(key, &deviceKey); err != nil { + continue + } + if deviceKey.Signatures == nil { + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sourceUserID, forSourceUser := range sigMap { + for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := deviceKey.Signatures[sourceUserID]; !ok { + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig + } + } + if js, err := json.Marshal(deviceKey); err == nil { + res.DeviceKeys[targetUserID][targetKeyID] = js + } + } + } + return nil +} + +func (a *UserInternalAPI) remoteKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, +) map[string]map[string][]string { + fetchRemote := make(map[string]map[string][]string) + for domain, userToDeviceMap := range domainToDeviceKeys { + for userID, deviceIDs := range userToDeviceMap { + // we can't safely return keys from the db when all devices are requested as we don't + // know if one has just been added. + if len(deviceIDs) > 0 { + err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) + if err == nil { + continue + } + util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase") + } + // fetch device lists from remote + if _, ok := fetchRemote[domain]; !ok { + fetchRemote[domain] = make(map[string][]string) + } + fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) + + } + } + return fetchRemote +} + +func (a *UserInternalAPI) queryRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, + domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, +) { + resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + // mutex for writing directly to res (e.g failures) + var respMu sync.Mutex + + domains := map[string]struct{}{} + for domain := range domainToDeviceKeys { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + domains[domain] = struct{}{} + } + for domain := range domainToCrossSigningKeys { + if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + domains[domain] = struct{}{} + } + wg.Add(len(domains)) + + // fan out + for domain := range domains { + go a.queryRemoteKeysOnServer( + ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain], + &wg, &respMu, timeout, resultCh, res, + ) + } + + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + processResult := func(result *gomatrixserverlib.RespQueryKeys) { + respMu.Lock() + defer respMu.Unlock() + for userID, nest := range result.DeviceKeys { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + for deviceID, deviceKey := range nest { + keyJSON, err := json.Marshal(deviceKey) + if err != nil { + continue + } + res.DeviceKeys[userID][deviceID] = keyJSON + } + } + + for userID, body := range result.MasterKeys { + res.MasterKeys[userID] = body + } + + for userID, body := range result.SelfSigningKeys { + res.SelfSigningKeys[userID] = body + } + + // TODO: do we want to persist these somewhere now + // that we have fetched them? + } + + for result := range resultCh { + processResult(result) + } +} + +func (a *UserInternalAPI) queryRemoteKeysOnServer( + ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, + wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, + res *api.QueryKeysResponse, +) { + defer wg.Done() + fedCtx := ctx + if timeout > 0 { + var cancel context.CancelFunc + fedCtx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + // for users who we do not have any knowledge about, try to start doing device list updates for them + // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but + // lack a stream ID. + userIDsForAllDevices := map[string]struct{}{} + for userID, deviceIDs := range devKeys { + if len(deviceIDs) == 0 { + userIDsForAllDevices[userID] = struct{}{} + } + } + // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing + // a device list update, so we'll populate those back into the /keys/query list if not + for userID := range crossSigningKeys { + if devKeys == nil { + devKeys = map[string][]string{} + } + if _, ok := userIDsForAllDevices[userID]; !ok { + devKeys[userID] = []string{} + } + } + for userID := range userIDsForAllDevices { + err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this + // user so the fact that we're populating all devices here isn't a problem so long as we have devices. + err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) + if err != nil { + logrus.WithFields(logrus.Fields{ + logrus.ErrorKey: err, + "user_id": userID, + "server": serverName, + }).Error("Failed to manually update device lists for user") + // try to do it via /keys/query + devKeys[userID] = []string{} + continue + } + } + if len(devKeys) == 0 { + return + } + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) + if err == nil { + resultCh <- &queryKeysResp + return + } + respMu.Lock() + res.Failures[serverName] = map[string]interface{}{ + "message": err.Error(), + } + respMu.Unlock() + + // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server + // is down, better to return something than nothing at all. Clients can know about the failure by + // inspecting the failures map though so they can know it's a cached response. + for userID, dkeys := range devKeys { + // drop the error as it's already a failure at this point + _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) + } + + // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache + respMu.Lock() + if len(res.DeviceKeys) > 0 { + delete(res.Failures, serverName) + } + respMu.Unlock() +} + +func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase( + ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, +) error { + keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false) + // if we can't query the db or there are fewer keys than requested, fetch from remote. + if err != nil { + return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) + } + if len(keys) < len(deviceIDs) { + return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID) + } + if len(deviceIDs) == 0 && len(keys) == 0 { + return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) + } + respMu.Lock() + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + respMu.Unlock() + + for _, key := range keys { + if len(key.KeyJSON) == 0 { + continue // ignore deleted keys + } + // inject the display name + key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{key.DisplayName}) + respMu.Lock() + res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + respMu.Unlock() + } + return nil +} + +func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + // get a list of devices from the user API that actually exist, as + // we won't store keys for devices that don't exist + uapidevices := &api.QueryDevicesResponse{} + if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { + res.Error = &api.KeyError{ + Err: err.Error(), + } + return + } + if !uapidevices.UserExists { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("user %q does not exist", req.UserID), + } + return + } + existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) + for _, key := range uapidevices.Devices { + existingDeviceMap[key.ID] = struct{}{} + } + + // Get all of the user existing device keys so we can check for changes. + existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), + } + return + } + + // Work out whether we have device keys in the keyserver for devices that + // no longer exist in the user API. This is mostly an exercise to ensure + // that we keep some integrity between the two. + var toClean []gomatrixserverlib.KeyID + for _, k := range existingKeys { + if _, ok := existingDeviceMap[k.DeviceID]; !ok { + toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) + } + } + + if len(toClean) > 0 { + if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { + logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) + } else { + logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) + } + } + + var keysToStore []api.DeviceMessage + + if req.OnlyDisplayNameUpdates { + for _, existingKey := range existingKeys { + for _, newKey := range req.DeviceKeys { + switch { + case existingKey.UserID != newKey.UserID: + continue + case existingKey.DeviceID != newKey.DeviceID: + continue + case existingKey.DisplayName != newKey.DisplayName: + existingKey.DisplayName = newKey.DisplayName + } + } + keysToStore = append(keysToStore, existingKey) + } + } else { + // assert that the user ID / device ID are not lying for each key + for _, key := range req.DeviceKeys { + var serverName gomatrixserverlib.ServerName + _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) + if err != nil { + continue // ignore invalid users + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + continue // ignore remote users + } + if len(key.KeyJSON) == 0 { + keysToStore = append(keysToStore, key.WithStreamID(0)) + continue // deleted keys don't need sanity checking + } + // check that the device in question actually exists in the user + // API before we try and store a key for it + if _, ok := existingDeviceMap[key.DeviceID]; !ok { + continue + } + gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str + gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str + if gotUserID == key.UserID && gotDeviceID == key.DeviceID { + keysToStore = append(keysToStore, key.WithStreamID(0)) + continue + } + + res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ + Err: fmt.Sprintf( + "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", + gotUserID, key.UserID, gotDeviceID, key.DeviceID, + ), + }) + } + } + + // store the device keys and emit changes + err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), + } + return + } + err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) + if err != nil { + util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) + } +} + +func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { + if req.UserID == "" { + res.Error = &api.KeyError{ + Err: "user ID missing", + } + } + if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { + counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err), + } + } + if counts != nil { + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) + } + return + } + for _, key := range req.OneTimeKeys { + // grab existing keys based on (user/device/algorithm/key ID) + keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) + i := 0 + for keyIDWithAlgo := range key.KeyJSON { + keyIDsWithAlgorithms[i] = keyIDWithAlgo + i++ + } + existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: "failed to query existing one-time keys: " + err.Error(), + }) + continue + } + for keyIDWithAlgo := range existingKeys { + // if keys exist and the JSON doesn't match, error out as the key already exists + if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo), + }) + continue + } + } + // store one-time keys + counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key) + if err != nil { + res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ + Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), + }) + continue + } + // collect counts + res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) + } + +} + +func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { + // if we only want to update the display names, we can skip the checks below + if onlyUpdateDisplayName { + return producer.ProduceKeyChanges(new) + } + // find keys in new that are not in existing + var keysAdded []api.DeviceMessage + for _, newKey := range new { + exists := false + for _, existingKey := range existing { + // Do not treat the absence of keys as equal, or else we will not emit key changes + // when users delete devices which never had a key to begin with as both KeyJSONs are nil. + if existingKey.DeviceKeysEqual(&newKey) { + exists = true + break + } + } + if !exists { + keysAdded = append(keysAdded, newKey) + } + } + return producer.ProduceKeyChanges(keysAdded) +} diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go new file mode 100644 index 00000000..fc7e7e0d --- /dev/null +++ b/userapi/internal/key_api_test.go @@ -0,0 +1,161 @@ +package internal_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create new user db: %v", err) + } + return db, func() { + base.Close() + close() + } +} + +func Test_QueryDeviceMessages(t *testing.T) { + alice := test.NewUser(t) + type args struct { + req *api.QueryDeviceMessagesRequest + res *api.QueryDeviceMessagesResponse + } + tests := []struct { + name string + args args + wantErr bool + want *api.QueryDeviceMessagesResponse + }{ + { + name: "no existing keys", + args: args{ + req: &api.QueryDeviceMessagesRequest{ + UserID: "@doesNotExist:localhost", + }, + res: &api.QueryDeviceMessagesResponse{}, + }, + want: &api.QueryDeviceMessagesResponse{}, + }, + { + name: "existing user returns devices", + args: args{ + req: &api.QueryDeviceMessagesRequest{ + UserID: alice.ID, + }, + res: &api.QueryDeviceMessagesResponse{}, + }, + want: &api.QueryDeviceMessagesResponse{ + StreamID: 6, + Devices: []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + DisplayName: "first device", + UserID: alice.ID, + KeyJSON: []byte("ghi"), + }, + }, + { + Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{ + DeviceID: "mySecondDevice", + DisplayName: "second device", + UserID: alice.ID, + KeyJSON: []byte("jkl"), + }, // streamID 6 + }, + }, + }, + }, + } + + deviceMessages := []api.DeviceMessage{ + { // not the user we're looking for + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + UserID: "@doesNotExist:localhost", + }, + // streamID 1 for this user + }, + { // empty keyJSON will be ignored + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + }, // streamID 1 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte("abc"), + }, // streamID 2 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte("def"), + }, // streamID 3 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + UserID: alice.ID, + KeyJSON: []byte(""), + }, // streamID 4 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "myDevice", + DisplayName: "first device", + UserID: alice.ID, + KeyJSON: []byte("ghi"), + }, // streamID 5 + }, + { + Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ + DeviceID: "mySecondDevice", + UserID: alice.ID, + KeyJSON: []byte("jkl"), + DisplayName: "second device", + }, // streamID 6 + }, + } + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateDatabase(t, dbType) + defer closeDB() + if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil { + t.Fatalf("failed to store local devicesKeys") + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &internal.UserInternalAPI{ + KeyDatabase: db, + } + if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { + t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) + } + got := tt.args.res + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want) + } + }) + } + }) +} diff --git a/userapi/internal/api.go b/userapi/internal/user_api.go index 0bb480da..1cbd9719 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/user_api.go @@ -23,6 +23,7 @@ import ( "strconv" "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -32,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" synctypes "github.com/matrix-org/dendrite/syncapi/types" @@ -44,17 +44,19 @@ import ( ) type UserInternalAPI struct { - DB storage.Database - SyncProducer *producers.SyncAPI - Config *config.UserAPI + DB storage.UserDatabase + KeyDatabase storage.KeyDatabase + SyncProducer *producers.SyncAPI + KeyChangeProducer *producers.KeyChange + Config *config.UserAPI DisableTLSValidation bool // AppServices is the list of all registered AS AppServices []config.ApplicationService - KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI PgClient pushgateway.Client - Cfg *config.UserAPI + FedClient fedsenderapi.KeyserverFederationAPI + Updater *DeviceListUpdater } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return fmt.Errorf("a.DB.SetDisplayName: %w", err) } - postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) + postRegisterJoinRooms(a.Config, acc, a.RSAPI) res.AccountCreated = true res.Account = acc @@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // Ask the keyserver to delete device keys and signatures for those devices - deleteReq := &keyapi.PerformDeleteKeysRequest{ + deleteReq := &api.PerformDeleteKeysRequest{ UserID: req.UserID, } for _, keyID := range req.DeviceIDs { deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID)) } - deleteRes := &keyapi.PerformDeleteKeysResponse{} - if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { + deleteRes := &api.PerformDeleteKeysResponse{} + if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil { return err } if err := deleteRes.Error; err != nil { @@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + deviceKeys := make([]api.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ + deviceKeys[i] = api.DeviceKeys{ UserID: userID, DeviceID: did, KeyJSON: nil, } } - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: userID, DeviceKeys: deviceKeys, }, &uploadRes); err != nil { @@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf } if req.DisplayName != nil && dev.DisplayName != *req.DisplayName { // display name has changed: update the device key - var uploadRes keyapi.PerformUploadKeysResponse - if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ + var uploadRes api.PerformUploadKeysResponse + if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{ UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ + DeviceKeys: []api.DeviceKeys{ { DeviceID: dev.ID, DisplayName: *req.DisplayName, diff --git a/userapi/producers/keychange.go b/userapi/producers/keychange.go new file mode 100644 index 00000000..da6cea31 --- /dev/null +++ b/userapi/producers/keychange.go @@ -0,0 +1,107 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +// KeyChange produces key change events for the sync API and federation sender to consume +type KeyChange struct { + Topic string + JetStream JetStreamPublisher + DB storage.KeyChangeDatabase +} + +// ProduceKeyChanges creates new change events for each key +func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { + userToDeviceCount := make(map[string]int) + for _, key := range keys { + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + key.DeviceChangeID = id + value, err := json.Marshal(key) + if err != nil { + return err + } + + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) + if err != nil { + return err + } + + userToDeviceCount[key.UserID]++ + } + for userID, count := range userToDeviceCount { + logrus.WithFields(logrus.Fields{ + "user_id": userID, + "num_key_changes": count, + }).Tracef("Produced to key change topic '%s'", p.Topic) + } + return nil +} + +func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error { + output := &api.DeviceMessage{ + Type: api.TypeCrossSigningUpdate, + OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ + CrossSigningKeyUpdate: key, + }, + } + + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + output.DeviceChangeID = id + + value, err := json.Marshal(output) + if err != nil { + return err + } + + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) + if err != nil { + return err + } + + logrus.WithFields(logrus.Fields{ + "user_id": key.UserID, + }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) + return nil +} diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 51eaa985..165de899 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -19,13 +19,13 @@ type JetStreamPublisher interface { // SyncAPI produces messages for the Sync API server to consume. type SyncAPI struct { - db storage.Database + db storage.Notification producer JetStreamPublisher clientDataTopic string notificationDataTopic string } -func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { +func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { return &SyncAPI{ db: db, producer: js, diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658..27837886 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. + // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d7..05716037 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data) return } diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go new file mode 100644 index 00000000..c0ecbd30 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -0,0 +1,102 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type SMALLINT NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go new file mode 100644 index 00000000..b0117145 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -0,0 +1,131 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + + " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..0cfe9e79 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go @@ -0,0 +1,69 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. We need to do a count(*) first to see if there + // even are entries in this table to know if we can query for log_offset. Without the count then + // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't + // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ + var count int + _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count) + if count > 0 { + var maxOffset int64 + _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { + return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) + } + } + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..1a3d4fee --- /dev/null +++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); + + DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go new file mode 100644 index 00000000..a9203857 --- /dev/null +++ b/userapi/storage/postgres/device_keys_table.go @@ -0,0 +1,213 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. + CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + // nullable if there are no results + var count sql.NullInt32 + err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7481ac5b..88f8839c 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7ba..91a34c35 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go new file mode 100644 index 00000000..a0049414 --- /dev/null +++ b/userapi/storage/postgres/key_changes_table.go @@ -0,0 +1,127 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. +const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + + " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go new file mode 100644 index 00000000..972a5914 --- /dev/null +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -0,0 +1,194 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) + for rows.Next() { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { + return nil, err + } + result[algorithmWithID] = json.RawMessage(keyJSONStr) + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go new file mode 100644 index 00000000..c823b58c --- /dev/null +++ b/userapi/storage/postgres/stale_device_lists.go @@ -0,0 +1,131 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 92dc4808..673d123b 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index bf94f14d..d3272a03 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) +} + +// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. +func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for keyType, keyData := range keyMap { + if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { + return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) + } + } + return nil + }) +} + +// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. +func (d *KeyDatabase) StoreCrossSigningSigsForTarget( + ctx context.Context, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) + } + return nil + }) +} + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *KeyDatabase) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go new file mode 100644 index 00000000..10721fcc --- /dev/null +++ b/userapi/storage/sqlite3/cross_signing_keys_table.go @@ -0,0 +1,101 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type INTEGER NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go new file mode 100644 index 00000000..2be00c9c --- /dev/null +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -0,0 +1,129 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..cd0f19df --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. + var maxOffset int64 + var userID string + _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + // to start counting from maxOffset, insert a row with that value + if userID != "" { + _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) + return err + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..d4e38dea --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,71 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, target_user_id, target_key_id) + ); + + INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) + SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; + + DROP TABLE keyserver_cross_signing_sigs; + ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; + + DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go new file mode 100644 index 00000000..15e69cc4 --- /dev/null +++ b/userapi/storage/sqlite3/device_keys_table.go @@ -0,0 +1,213 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. + UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id)" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt64 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int64), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 449e4549..65e17527 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, @@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb1..ed274631 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go new file mode 100644 index 00000000..923bb57e --- /dev/null +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -0,0 +1,125 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. +const upsertKeyChangeSQL = "" + + "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go new file mode 100644 index 00000000..a992d399 --- /dev/null +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -0,0 +1,208 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys( + ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, +) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + if err != nil { + return nil, err + } + if keyJSON == "" { + return nil, nil + } + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go new file mode 100644 index 00000000..f078fc99 --- /dev/null +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,145 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + +type staleDeviceListsStatements struct { + db *sql.DB + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + } + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c94..72b3ba49 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f706..0f3eeed1 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) + if err != nil { + return nil, err + } + otk, err := NewSqliteOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewSqliteDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewSqliteKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewSqliteCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewSqliteCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go index 42221e75..0329fb46 100644 --- a/userapi/storage/storage.go +++ b/userapi/storage/storage.go @@ -29,15 +29,36 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3" ) -// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // and sets postgres connection parameters -func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { +func NewUserDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + serverName gomatrixserverlib.ServerName, + bcryptCost int, + openIDTokenLifetimeMS int64, + loginTokenLifetime time.Duration, + serverNoticesLocalpart string, +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) default: return nil, fmt.Errorf("unexpected database type") } } + +// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme) +// and sets postgres connection parameters. +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23aafff0..f52e7e17 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + "another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e..163e3e17 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 9221e571..693e7303 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -145,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. + // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go new file mode 100644 index 00000000..b9bdafda --- /dev/null +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/userapi/types/storage.go b/userapi/types/storage.go new file mode 100644 index 00000000..7fb90454 --- /dev/null +++ b/userapi/types/storage.go @@ -0,0 +1,50 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "math" + + "github.com/matrix-org/gomatrixserverlib" +) + +const ( + // OffsetNewest tells e.g. the database to get the most current data + OffsetNewest int64 = math.MaxInt64 + // OffsetOldest tells e.g. the database to get the oldest data + OffsetOldest int64 = 0 +) + +// KeyTypePurposeToInt maps a purpose to an integer, which is used in the +// database to reduce the amount of space taken up by this column. +var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{ + gomatrixserverlib.CrossSigningKeyPurposeMaster: 1, + gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2, + gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3, +} + +// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the +// database to reduce the amount of space taken up by this column. +var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ + 1: gomatrixserverlib.CrossSigningKeyPurposeMaster, + 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, + 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, +} + +// Map of purpose -> public key +type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes + +// Map of user ID -> key ID -> signature +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes diff --git a/userapi/userapi.go b/userapi/userapi.go index 2dd81d75..826bd721 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -17,13 +17,11 @@ package userapi import ( "time" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/sirupsen/logrus" - "github.com/matrix-org/dendrite/internal/pushgateway" - keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/consumers" @@ -33,16 +31,20 @@ import ( "github.com/matrix-org/dendrite/userapi/util" ) -// NewInternalAPI returns a concerete implementation of the internal API. Callers +// NewInternalAPI returns a concrete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - base *base.BaseDendrite, cfg *config.UserAPI, - appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI, - rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client, -) api.UserInternalAPI { + base *base.BaseDendrite, + rsAPI rsapi.UserRoomserverAPI, + fedClient fedsenderapi.KeyserverFederationAPI, +) *internal.UserInternalAPI { + cfg := &base.Cfg.UserAPI js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + appServices := base.Cfg.Derived.ApplicationServices - db, err := storage.NewUserAPIDatabase( + pgClient := base.PushGatewayHTTPClient() + + db, err := storage.NewUserDatabase( base, &cfg.AccountDatabase, cfg.Matrix.ServerName, @@ -55,6 +57,11 @@ func NewInternalAPI( logrus.WithError(err).Panicf("failed to connect to accounts db") } + keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to key db") + } + syncProducer := producers.NewSyncAPI( db, js, // TODO: user API should handle syncs for account data. Right now, @@ -64,17 +71,50 @@ func NewInternalAPI( cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData), cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData), ) + keyChangeProducer := &producers.KeyChange{ + Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), + JetStream: js, + DB: keyDB, + } userAPI := &internal.UserInternalAPI{ DB: db, + KeyDatabase: keyDB, SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, Config: cfg, AppServices: appServices, - KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, PgClient: pgClient, - Cfg: cfg, + FedClient: fedClient, + } + + updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable + userAPI.Updater = updater + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + + go func() { + if err := updater.Start(); err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } + }() + + dlConsumer := consumers.NewDeviceListUpdateConsumer( + base.ProcessContext, cfg, js, updater, + ) + if err := dlConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start device list consumer") + } + + sigConsumer := consumers.NewSigningKeyUpdateConsumer( + base.ProcessContext, cfg, js, userAPI, + ) + if err := sigConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start signing key consumer") } receiptConsumer := consumers.NewOutputReceiptEventConsumer( diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 68d08c2f..08b1336b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -21,7 +21,10 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/nats-io/nats.go" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/setup/config" @@ -38,32 +41,55 @@ const ( type apiTestOpts struct { loginTokenLifetime time.Duration + serverName string } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) { +type dummyProducer struct{} + +func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) { + return &nats.PubAck{}, nil +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + sName := serverName + if opts.serverName != "" { + sName = gomatrixserverlib.ServerName(opts.serverName) + } + accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") + }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") if err != nil { t.Fatalf("failed to create account DB: %s", err) } + keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }) + if err != nil { + t.Fatalf("failed to create key DB: %s", err) + } + cfg := &config.UserAPI{ Matrix: &config.Global{ SigningIdentity: gomatrixserverlib.SigningIdentity{ - ServerName: serverName, + ServerName: sName, }, }, } + syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}} return &internal.UserInternalAPI{ - DB: accountDB, - Config: cfg, + DB: accountDB, + KeyDatabase: keyDB, + Config: cfg, + SyncProducer: syncProducer, + KeyChangeProducer: keyChangeProducer, }, accountDB, func() { close() baseclose() @@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) { testCases(t, intAPI) }) } + +func TestAccountData(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + + testCases := []struct { + name string + inputData *api.InputAccountDataRequest + wantErr bool + }{ + { + name: "not a local user", + inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"}, + wantErr: true, + }, + { + name: "local user missing datatype", + inputData: &api.InputAccountDataRequest{UserID: alice.ID}, + wantErr: true, + }, + { + name: "missing json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil}, + wantErr: true, + }, + { + name: "with json", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")}, + }, + { + name: "room data", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"}, + }, + { + name: "ignored users", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")}, + }, + { + name: "m.fully_read", + inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res := api.InputAccountDataResponse{} + err := intAPI.InputAccountData(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + + // query the data again and compare + queryRes := api.QueryAccountDataResponse{} + queryReq := api.QueryAccountDataRequest{ + UserID: tc.inputData.UserID, + DataType: tc.inputData.DataType, + RoomID: tc.inputData.RoomID, + } + err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes) + if err != nil && !tc.wantErr { + t.Fatal(err) + } + // verify global data + if tc.inputData.RoomID == "" { + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType])) + } + } else { + // verify room data + if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) { + t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType])) + } + } + }) + } + }) +} + +func TestDevices(t *testing.T) { + ctx := context.Background() + + dupeAccessToken := util.RandomString(8) + + displayName := "testing" + + creationTests := []struct { + name string + inputData *api.PerformDeviceCreationRequest + wantErr bool + wantNewDevID bool + }{ + { + name: "not a local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"}, + wantErr: true, + }, + { + name: "implicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName}, + }, + { + name: "explicit local user", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "dupe token - ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + }, + { + name: "dupe token - not ok", + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true}, + wantErr: true, + }, + { + name: "test3 second device", // used to test deletion later + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + { + name: "test3 third device", // used to test deletion later + wantNewDevID: true, + inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true}, + }, + } + + deletionTests := []struct { + name string + inputData *api.PerformDeviceDeletionRequest + wantErr bool + wantDevices int + }{ + { + name: "deletion - not a local user", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"}, + wantErr: true, + }, + { + name: "deleting not existing devices should not error", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}}, + wantDevices: 1, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"}, + wantDevices: 0, + }, + { + name: "delete all devices", + inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"}, + wantDevices: 0, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + for _, tc := range creationTests { + t.Run(tc.name, func(t *testing.T) { + res := api.PerformDeviceCreationResponse{} + deviceID := util.RandomString(8) + tc.inputData.DeviceID = &deviceID + if tc.wantNewDevID { + tc.inputData.DeviceID = nil + } + err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if !res.DeviceCreated { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + // We only want to verify one device + if len(queryDevicesRes.Devices) > 1 { + return + } + res.Device.AccessToken = "" + + // At this point, there should only be one device + if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) { + t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0]) + } + + newDisplayName := "new name" + if tc.inputData.DeviceDisplayName == nil { + updateRes := api.PerformDeviceUpdateResponse{} + updateReq := api.PerformDeviceUpdateRequest{ + RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"), + DeviceID: deviceID, + DisplayName: &newDisplayName, + } + + if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil { + t.Fatal(err) + } + } + + queryDeviceInfosRes := api.QueryDeviceInfosResponse{} + queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}} + if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil { + t.Fatal(err) + } + gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName + if tc.inputData.DeviceDisplayName != nil { + wantDisplayName := *tc.inputData.DeviceDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } else { + wantDisplayName := newDisplayName + if wantDisplayName != gotDisplayName { + t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName) + } + } + }) + } + + for _, tc := range deletionTests { + t.Run(tc.name, func(t *testing.T) { + delRes := api.PerformDeviceDeletionResponse{} + err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes) + if tc.wantErr && err == nil { + t.Fatalf("expected an error, but got none") + } + if !tc.wantErr && err != nil { + t.Fatalf("expected no error, but got: %s", err) + } + if tc.wantErr { + return + } + + queryDevicesRes := api.QueryDevicesResponse{} + queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID} + if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil { + t.Fatal(err) + } + + if len(queryDevicesRes.Devices) != tc.wantDevices { + t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices)) + } + + }) + } + }) +} + +// Tests that the session ID of a device is not reused when reusing the same device ID. +func TestDeviceIDReuse(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + defer close() + + res := api.PerformDeviceCreationResponse{} + // create a first device + deviceID := util.RandomString(8) + req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true} + err := intAPI.PerformDeviceCreation(ctx, &req, &res) + if err != nil { + t.Fatal(err) + } + + // Do the same request again, we expect a different sessionID + res2 := api.PerformDeviceCreationResponse{} + err = intAPI.PerformDeviceCreation(ctx, &req, &res2) + if err != nil { + t.Fatalf("expected no error, but got: %v", err) + } + + if res2.Device.SessionID == res.Device.SessionID { + t.Fatalf("expected a different session ID, but they are the same") + } + }) +} diff --git a/userapi/util/devices.go b/userapi/util/devices.go index c55fc799..31617d8c 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -19,7 +19,7 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) { pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { return nil, fmt.Errorf("db.GetPushers: %w", err) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index fc0ab39b..08d1371d 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -13,11 +13,11 @@ import ( ) // NotifyUserCountsAsync sends notifications to a local user's -// notification destinations. Database lookups run synchronously, but +// notification destinations. UserDatabase lookups run synchronously, but // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error { pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index f1d20259..421852d3 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { defer close() base, _, _ := testrig.Base(nil) defer base.Close() - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "test", bcrypt.MinCost, 0, 0, "") if err != nil { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go index 6e62210e..5f626b5b 100644 --- a/userapi/util/phonehomestats_test.go +++ b/userapi/util/phonehomestats_test.go @@ -21,7 +21,7 @@ func TestCollect(t *testing.T) { b, _, _ := testrig.Base(nil) connStr, closeDB := test.PrepareDBConnectionString(t, dbType) defer closeDB() - db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, 1000, 1000, "") if err != nil { |