diff options
Diffstat (limited to 'keyserver')
40 files changed, 0 insertions, 6522 deletions
diff --git a/keyserver/README.md b/keyserver/README.md deleted file mode 100644 index fd9f37d2..00000000 --- a/keyserver/README.md +++ /dev/null @@ -1,19 +0,0 @@ -## Key Server - -This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID. - -Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API. - -### Internal APIs -- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s). -- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls. -- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers. -- A topic which emits identity keys every time there is a change (addition or deletion). - -### Endpoint mappings -- Client API maps `/keys/upload` to `PerformUploadKeys`. -- Client API maps `/keys/query` to `QueryKeys`. -- Client API maps `/keys/claim` to `PerformClaimKeys`. -- Federation API maps `/user/keys/query` to `QueryKeys`. -- Federation API maps `/user/keys/claim` to `PerformClaimKeys`. -- Sync API maps `/keys/changes` to consuming from the Kafka topic. diff --git a/keyserver/api/api.go b/keyserver/api/api.go deleted file mode 100644 index 14fced3e..00000000 --- a/keyserver/api/api.go +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package api - -import ( - "bytes" - "context" - "encoding/json" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/types" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI interface { - SyncKeyAPI - ClientKeyAPI - FederationKeyAPI - UserKeyAPI - - // SetUserAPI assigns a user API to query when extracting device names. - SetUserAPI(i userapi.KeyserverUserAPI) -} - -// API functions required by the clientapi -type ClientKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error - // PerformClaimKeys claims one-time keys for use in pre-key messages - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -// API functions required by the userapi -type UserKeyAPI interface { - PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error - PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error -} - -// API functions required by the syncapi -type SyncKeyAPI interface { - QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error - QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error - PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error -} - -type FederationKeyAPI interface { - QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error - QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error - QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error - PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error - PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error -} - -// KeyError is returned if there was a problem performing/querying the server -type KeyError struct { - Err string `json:"error"` - IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE - IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM - IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM -} - -func (k *KeyError) Error() string { - return k.Err -} - -type DeviceMessageType int - -const ( - TypeDeviceKeyUpdate DeviceMessageType = iota - TypeCrossSigningUpdate -) - -// DeviceMessage represents the message produced into Kafka by the key server. -type DeviceMessage struct { - Type DeviceMessageType `json:"Type,omitempty"` - *DeviceKeys `json:"DeviceKeys,omitempty"` - *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` - // A monotonically increasing number which represents device changes for this user. - StreamID int64 - DeviceChangeID int64 -} - -// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log -type OutputCrossSigningKeyUpdate struct { - CrossSigningKeyUpdate `json:"signing_keys"` -} - -type CrossSigningKeyUpdate struct { - MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"` - SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"` - UserID string `json:"user_id"` -} - -// DeviceKeysEqual returns true if the device keys updates contain the -// same display name and key JSON. This will return false if either of -// the updates is not a device keys update, or if the user ID/device ID -// differ between the two. -func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool { - if m1.DeviceKeys == nil || m2.DeviceKeys == nil { - return false - } - if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID { - return false - } - if m1.DisplayName != m2.DisplayName { - return false // different display names - } - if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 { - return false // either is empty - } - return bytes.Equal(m1.KeyJSON, m2.KeyJSON) -} - -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage { - return DeviceMessage{ - DeviceKeys: k, - StreamID: streamID, - } -} - -// OneTimeKeys represents a set of one-time keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type OneTimeKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // A map of algorithm:key_id => key JSON - KeyJSON map[string]json.RawMessage -} - -// Split a key in KeyJSON into algorithm and key ID -func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) { - segments := strings.Split(keyIDWithAlgo, ":") - return segments[0], segments[1] -} - -// OneTimeKeysCount represents the counts of one-time keys for a single device -type OneTimeKeysCount struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // algorithm to count e.g: - // { - // "curve25519": 10, - // "signed_curve25519": 20 - // } - KeyCount map[string]int -} - -// PerformUploadKeysRequest is the request to PerformUploadKeys -type PerformUploadKeysRequest struct { - UserID string // Required - User performing the request - DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys - OneTimeKeys []OneTimeKeys - // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update - // the display name for their respective device, and NOT to modify the keys. The key - // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths. - // Without this flag, requests to modify device display names would delete device keys. - OnlyDisplayNameUpdates bool -} - -// PerformUploadKeysResponse is the response to PerformUploadKeys -type PerformUploadKeysResponse struct { - // A fatal error when processing e.g database failures - Error *KeyError - // A map of user_id -> device_id -> Error for tracking failures. - KeyErrors map[string]map[string]*KeyError - OneTimeKeyCounts []OneTimeKeysCount -} - -// PerformDeleteKeysRequest asks the keyserver to forget about certain -// keys, and signatures related to those keys. -type PerformDeleteKeysRequest struct { - UserID string - KeyIDs []gomatrixserverlib.KeyID -} - -// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest. -type PerformDeleteKeysResponse struct { - Error *KeyError -} - -// KeyError sets a key error field on KeyErrors -func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) { - if r.KeyErrors[userID] == nil { - r.KeyErrors[userID] = make(map[string]*KeyError) - } - r.KeyErrors[userID][deviceID] = err -} - -type PerformClaimKeysRequest struct { - // Map of user_id to device_id to algorithm name - OneTimeKeys map[string]map[string]string - Timeout time.Duration -} - -type PerformClaimKeysResponse struct { - // Map of user_id to device_id to algorithm:key_id to key JSON - OneTimeKeys map[string]map[string]map[string]json.RawMessage - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Set if there was a fatal error processing this action - Error *KeyError -} - -type PerformUploadDeviceKeysRequest struct { - gomatrixserverlib.CrossSigningKeys - // The user that uploaded the key, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceKeysResponse struct { - Error *KeyError -} - -type PerformUploadDeviceSignaturesRequest struct { - Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice - // The user that uploaded the sig, should be populated by the clientapi. - UserID string -} - -type PerformUploadDeviceSignaturesResponse struct { - Error *KeyError -} - -type QueryKeysRequest struct { - // The user ID asking for the keys, e.g. if from a client API request. - // Will not be populated if the key request came from federation. - UserID string - // Maps user IDs to a list of devices - UserToDevices map[string][]string - Timeout time.Duration -} - -type QueryKeysResponse struct { - // Map of remote server domain to error JSON - Failures map[string]interface{} - // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage - // Maps of user_id to cross signing key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // Set if there was a fatal error processing this query - Error *KeyError -} - -type QueryKeyChangesRequest struct { - // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning - Offset int64 - // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. - // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing). - ToOffset int64 -} - -type QueryKeyChangesResponse struct { - // The set of users who have had their keys change. - UserIDs []string - // The latest offset represented in this response. - Offset int64 - // Set if there was a problem handling the request. - Error *KeyError -} - -type QueryOneTimeKeysRequest struct { - // The local user to query OTK counts for - UserID string - // The device to query OTK counts for - DeviceID string -} - -type QueryOneTimeKeysResponse struct { - // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84 - Count OneTimeKeysCount - Error *KeyError -} - -type QueryDeviceMessagesRequest struct { - UserID string -} - -type QueryDeviceMessagesResponse struct { - // The latest stream ID - StreamID int64 - Devices []DeviceMessage - Error *KeyError -} - -type QuerySignaturesRequest struct { - // A map of target user ID -> target key/device IDs to retrieve signatures for - TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"` -} - -type QuerySignaturesResponse struct { - // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures - Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap - // A map of target user ID -> cross-signing master key - MasterKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing self-signing key - SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // A map of target user ID -> cross-signing user-signing key - UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey - // The request error, if any - Error *KeyError -} - -type PerformMarkAsStaleRequest struct { - UserID string - Domain gomatrixserverlib.ServerName - DeviceID string -} diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go deleted file mode 100644 index cd911f8c..00000000 --- a/keyserver/consumers/devicelistupdate.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" -) - -// DeviceListUpdateConsumer consumes device list updates that came in over federation. -type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - isLocalServerName func(gomatrixserverlib.ServerName) bool -} - -// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. -func NewDeviceListUpdateConsumer( - process *process.ProcessContext, - cfg *config.KeyServer, - js nats.JetStreamContext, - updater *internal.DeviceListUpdater, -) *DeviceListUpdateConsumer { - return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - isLocalServerName: cfg.Matrix.IsLocalServerName, - } -} - -// Start consuming from key servers -func (t *DeviceListUpdateConsumer) Start() error { - return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, 1, - t.onMessage, nats.DeliverAll(), nats.ManualAck(), - ) -} - -// onMessage is called in response to a message received on the -// key change events topic from the key server. -func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var m gomatrixserverlib.DeviceListUpdateEvent - if err := json.Unmarshal(msg.Data, &m); err != nil { - logrus.WithError(err).Errorf("Failed to read from device list update input topic") - return true - } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) - if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { - return true - } else if t.isLocalServerName(serverName) { - return true - } else if serverName != origin { - return true - } - - err := t.updater.Update(ctx, m) - if err != nil { - logrus.WithFields(logrus.Fields{ - "user_id": m.UserID, - "device_id": m.DeviceID, - "stream_id": m.StreamID, - "prev_id": m.PrevID, - }).WithError(err).Errorf("Failed to update device list") - return false - } - return true -} diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go deleted file mode 100644 index bcceaad1..00000000 --- a/keyserver/consumers/signingkeyupdate.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - - keyapi "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" -) - -// SigningKeyUpdateConsumer consumes signing key updates that came in over federation. -type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer - isLocalServerName func(gomatrixserverlib.ServerName) bool -} - -// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. -func NewSigningKeyUpdateConsumer( - process *process.ProcessContext, - cfg *config.KeyServer, - js nats.JetStreamContext, - keyAPI *internal.KeyInternalAPI, -) *SigningKeyUpdateConsumer { - return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, - isLocalServerName: cfg.Matrix.IsLocalServerName, - } -} - -// Start consuming from key servers -func (t *SigningKeyUpdateConsumer) Start() error { - return jetstream.JetStreamConsumer( - t.ctx, t.jetstream, t.topic, t.durable, 1, - t.onMessage, nats.DeliverAll(), nats.ManualAck(), - ) -} - -// onMessage is called in response to a message received on the -// signing key update events topic from the key server. -func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var updatePayload keyapi.CrossSigningKeyUpdate - if err := json.Unmarshal(msg.Data, &updatePayload); err != nil { - logrus.WithError(err).Errorf("Failed to read from signing key update input topic") - return true - } - origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) - if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { - logrus.WithError(err).Error("failed to split user id") - return true - } else if t.isLocalServerName(serverName) { - logrus.Warn("dropping device key update from ourself") - return true - } else if serverName != origin { - logrus.Warnf("dropping device key update, %s != %s", serverName, origin) - return true - } - - keys := gomatrixserverlib.CrossSigningKeys{} - if updatePayload.MasterKey != nil { - keys.MasterKey = *updatePayload.MasterKey - } - if updatePayload.SelfSigningKey != nil { - keys.SelfSigningKey = *updatePayload.SelfSigningKey - } - uploadReq := &keyapi.PerformUploadDeviceKeysRequest{ - CrossSigningKeys: keys, - UserID: updatePayload.UserID, - } - uploadRes := &keyapi.PerformUploadDeviceKeysResponse{} - if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil { - logrus.WithError(err).Error("failed to upload device keys") - return false - } - if uploadRes.Error != nil { - logrus.WithError(uploadRes.Error).Error("failed to upload device keys") - return true - } - - return true -} diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go deleted file mode 100644 index 99859dff..00000000 --- a/keyserver/internal/cross_signing.go +++ /dev/null @@ -1,587 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "bytes" - "context" - "crypto/ed25519" - "database/sql" - "fmt" - "strings" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/curve25519" -) - -func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error { - // Is there exactly one key? - if len(key.Keys) != 1 { - return fmt.Errorf("should contain exactly one key") - } - - // Does the key ID match the key value? Iterates exactly once - for keyID, keyData := range key.Keys { - b64 := keyData.Encode() - tokens := strings.Split(string(keyID), ":") - if len(tokens) != 2 { - return fmt.Errorf("key ID is incorrectly formatted") - } - if tokens[1] != b64 { - return fmt.Errorf("key ID isn't correct") - } - switch tokens[0] { - case "ed25519": - if len(keyData) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key is not the correct length") - } - case "curve25519": - if len(keyData) != curve25519.PointSize { - return fmt.Errorf("curve25519 key is not the correct length") - } - default: - // We can't enforce the key length to be correct for an - // algorithm that we don't recognise, so instead we'll - // just make sure that it isn't incredibly excessive. - if l := len(keyData); l > 4096 { - return fmt.Errorf("unknown key type is too long (%d bytes)", l) - } - } - } - - // Check to see if the signatures make sense - for _, forOriginUser := range key.Signatures { - for originKeyID, originSignature := range forOriginUser { - switch strings.SplitN(string(originKeyID), ":", 1)[0] { - case "ed25519": - if len(originSignature) != ed25519.SignatureSize { - return fmt.Errorf("ed25519 signature is not the correct length") - } - case "curve25519": - return fmt.Errorf("curve25519 signatures are impossible") - default: - if l := len(originSignature); l > 4096 { - return fmt.Errorf("unknown signature type is too long (%d bytes)", l) - } - } - } - } - - // Does the key claim to be from the right user? - if userID != key.UserID { - return fmt.Errorf("key has a user ID mismatch") - } - - // Does the key contain the correct purpose? - useful := false - for _, usage := range key.Usage { - if usage == purpose { - useful = true - break - } - } - if !useful { - return fmt.Errorf("key does not contain correct usage purpose") - } - - return nil -} - -// nolint:gocyclo -func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - // Find the keys to store. - byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - toStore := types.CrossSigningKeyMap{} - hasMasterKey := false - - if len(req.MasterKey.Keys) > 0 { - if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil { - res.Error = &api.KeyError{ - Err: "Master key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey - for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key - } - hasMasterKey = true - } - - if len(req.SelfSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil { - res.Error = &api.KeyError{ - Err: "Self-signing key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey - for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key - } - } - - if len(req.UserSigningKey.Keys) > 0 { - if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil { - res.Error = &api.KeyError{ - Err: "User-signing key sanity check failed: " + err.Error(), - IsInvalidParam: true, - } - return nil - } - - byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey - for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey - toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key - } - } - - // If there's nothing to do then stop here. - if len(toStore) == 0 { - res.Error = &api.KeyError{ - Err: "No keys were supplied in the request", - IsMissingParam: true, - } - return nil - } - - // We can't have a self-signing or user-signing key without a master - // key, so make sure we have one of those. We will also only actually do - // something if any of the specified keys in the request are different - // to what we've got in the database, to avoid generating key change - // notifications unnecessarily. - existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID) - if err != nil { - res.Error = &api.KeyError{ - Err: "Retrieving cross-signing keys from database failed: " + err.Error(), - } - return nil - } - - // If we still can't find a master key for the user then stop the upload. - // This satisfies the "Fails to upload self-signing key without master key" test. - if !hasMasterKey { - if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey { - res.Error = &api.KeyError{ - Err: "No master key was found", - IsMissingParam: true, - } - return nil - } - } - - // Check if anything actually changed compared to what we have in the database. - changed := false - for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{ - gomatrixserverlib.CrossSigningKeyPurposeMaster, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning, - } { - old, gotOld := existingKeys[purpose] - new, gotNew := toStore[purpose] - if gotOld != gotNew { - // A new key purpose has been specified that we didn't know before, - // or one has been removed. - changed = true - break - } - if !bytes.Equal(old, new) { - // One of the existing keys for a purpose we already knew about has - // changed. - changed = true - break - } - } - if !changed { - return nil - } - - // Store the keys. - if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err), - } - return nil - } - - // Now upload any signatures that were included with the keys. - for _, key := range byPurpose { - var targetKeyID gomatrixserverlib.KeyID - for targetKey := range key.Keys { // iterates once, see sanityCheckKey - targetKeyID = targetKey - } - for sigUserID, forSigUserID := range key.Signatures { - if sigUserID != req.UserID { - continue - } - for sigKeyID, sigBytes := range forSigUserID { - if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err), - } - return nil - } - } - } - } - - // Finally, generate a notification that we updated the keys. - update := api.CrossSigningKeyUpdate{ - UserID: req.UserID, - } - if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok { - update.MasterKey = &mk - } - if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok { - update.SelfSigningKey = &ssk - } - if update.MasterKey == nil && update.SelfSigningKey == nil { - return nil - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return nil - } - return nil -} - -func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error { - // Before we do anything, we need the master and self-signing keys for this user. - // Then we can verify the signatures make sense. - queryReq := &api.QueryKeysRequest{ - UserID: req.UserID, - UserToDevices: map[string][]string{}, - } - queryRes := &api.QueryKeysResponse{} - for userID := range req.Signatures { - queryReq.UserToDevices[userID] = []string{} - } - _ = a.QueryKeys(ctx, queryReq, queryRes) - - selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - - // Sort signatures into two groups: one where people have signed their own - // keys and one where people have signed someone elses - for userID, forUserID := range req.Signatures { - for keyID, keyOrDevice := range forUserID { - switch key := keyOrDevice.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - if key.UserID == req.UserID { - if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - selfSignatures[userID][keyID] = keyOrDevice - } else { - if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - otherSignatures[userID][keyID] = keyOrDevice - } - - case *gomatrixserverlib.DeviceKeys: - if key.UserID == req.UserID { - if _, ok := selfSignatures[userID]; !ok { - selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - selfSignatures[userID][keyID] = keyOrDevice - } else { - if _, ok := otherSignatures[userID]; !ok { - otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{} - } - otherSignatures[userID][keyID] = keyOrDevice - } - - default: - continue - } - } - } - - if err := a.processSelfSignatures(ctx, selfSignatures); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.processSelfSignatures: %s", err), - } - return nil - } - - if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.processOtherSignatures: %s", err), - } - return nil - } - - // Finally, generate a notification that we updated the signatures. - for userID := range req.Signatures { - masterKey := queryRes.MasterKeys[userID] - selfSigningKey := queryRes.SelfSigningKeys[userID] - update := api.CrossSigningKeyUpdate{ - UserID: userID, - MasterKey: &masterKey, - SelfSigningKey: &selfSigningKey, - } - if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err), - } - return nil - } - } - return nil -} - -func (a *KeyInternalAPI) processSelfSignatures( - ctx context.Context, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, -) error { - // Here we will process: - // * The user signing their own devices using their self-signing key - // * The user signing their master key using one of their devices - - for targetUserID, forTargetUserID := range signatures { - for targetKeyID, signature := range forTargetUserID { - switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - for keyID := range sig.Keys { - split := strings.SplitN(string(keyID), ":", 2) - if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID { - targetKeyID = keyID // contains the ed25519: or other scheme - break - } - } - for originUserID, forOriginUserID := range sig.Signatures { - for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - case *gomatrixserverlib.DeviceKeys: - for originUserID, forOriginUserID := range sig.Signatures { - for originKeyID, originSig := range forOriginUserID { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - default: - return fmt.Errorf("unexpected type assertion") - } - } - } - - return nil -} - -func (a *KeyInternalAPI) processOtherSignatures( - ctx context.Context, userID string, queryRes *api.QueryKeysResponse, - signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice, -) error { - // Here we will process: - // * A user signing someone else's master keys using their user-signing keys - - for targetUserID, forTargetUserID := range signatures { - for _, signature := range forTargetUserID { - switch sig := signature.CrossSigningBody.(type) { - case *gomatrixserverlib.CrossSigningKey: - // Find the local copy of the master key. We'll use this to be - // sure that the supplied stanza matches the key that we think it - // should be. - masterKey, ok := queryRes.MasterKeys[targetUserID] - if !ok { - return fmt.Errorf("failed to find master key for user %q", targetUserID) - } - - // For each key ID, write the signatures. Maybe there'll be more - // than one algorithm in the future so it's best not to focus on - // everything being ed25519:. - for targetKeyID, suppliedKeyData := range sig.Keys { - // The master key will be supplied in the request, but we should - // make sure that it matches what we think the master key should - // actually be. - localKeyData, lok := masterKey.Keys[targetKeyID] - if !lok { - return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) - } else if !bytes.Equal(suppliedKeyData, localKeyData) { - return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID) - } - - // We only care about the signatures from the uploading user, so - // we will ignore anything that didn't originate from them. - userSigs, ok := sig.Signatures[userID] - if !ok { - return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID) - } - - for originKeyID, originSig := range userSigs { - if err := a.DB.StoreCrossSigningSigsForTarget( - ctx, userID, originKeyID, targetUserID, targetKeyID, originSig, - ); err != nil { - return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err) - } - } - } - - default: - // Users should only be signing another person's master key, - // so if we're here, it's probably because it's actually a - // gomatrixserverlib.DeviceKeys, which doesn't make sense. - } - } - } - - return nil -} - -func (a *KeyInternalAPI) crossSigningKeysFromDatabase( - ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, -) { - for targetUserID := range req.UserToDevices { - keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) - if err != nil { - logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID) - continue - } - - for keyType, key := range keys { - var keyID gomatrixserverlib.KeyID - for id := range key.Keys { - keyID = id - break - } - - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID) - if err != nil && err != sql.ErrNoRows { - logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID) - continue - } - - appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { - if key.Signatures == nil { - key.Signatures = types.CrossSigningSigMap{} - } - if _, ok := key.Signatures[originUserID]; !ok { - key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) - } - key.Signatures[originUserID][originKeyID] = signature - } - - for originUserID, forOrigin := range sigMap { - for originKeyID, signature := range forOrigin { - switch { - case req.UserID != "" && originUserID == req.UserID: - // Include signatures that we created - appendSignature(originUserID, originKeyID, signature) - case originUserID == targetUserID: - // Include signatures that were created by the person whose key - // we are processing - appendSignature(originUserID, originKeyID, signature) - } - } - } - - switch keyType { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: - res.MasterKeys[targetUserID] = key - - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: - res.SelfSigningKeys[targetUserID] = key - - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: - res.UserSigningKeys[targetUserID] = key - } - } - } -} - -func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error { - for targetUserID, forTargetUser := range req.TargetIDs { - keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID) - if err != nil && err != sql.ErrNoRows { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err), - } - continue - } - - for targetPurpose, targetKey := range keyMap { - switch targetPurpose { - case gomatrixserverlib.CrossSigningKeyPurposeMaster: - if res.MasterKeys == nil { - res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.MasterKeys[targetUserID] = targetKey - - case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: - if res.SelfSigningKeys == nil { - res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.SelfSigningKeys[targetUserID] = targetKey - - case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: - if res.UserSigningKeys == nil { - res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{} - } - res.UserSigningKeys[targetUserID] = targetKey - } - } - - for _, targetKeyID := range forTargetUser { - // Get own signatures only. - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID) - if err != nil && err != sql.ErrNoRows { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err), - } - return nil - } - - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if res.Signatures == nil { - res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID]; !ok { - res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok { - res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{} - } - if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok { - res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig - } - } - } - } - return nil -} diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go deleted file mode 100644 index 1b00d1ee..00000000 --- a/keyserver/internal/device_list_update.go +++ /dev/null @@ -1,579 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "hash/fnv" - "net" - "sync" - "time" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/setup/process" -) - -var ( - deviceListUpdateCount = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "keyserver", - Name: "device_list_update", - Help: "Number of times we have attempted to update device lists from this server", - }, - []string{"server"}, - ) -) - -const requestTimeout = time.Second * 30 - -func init() { - prometheus.MustRegister( - deviceListUpdateCount, - ) -} - -// DeviceListUpdater handles device list updates from remote servers. -// -// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock). -// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies -// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id -// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device: -// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the -// updater stores the latest list along with the latest stream ID. -// -// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers. -// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing -// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved -// from the database (which allows us to batch requests to the same server). This has a number of desirable properties: -// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible -// for that domain. -// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where -// we have many many servers) -// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers. -// -// The downsides are that: -// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free -// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts) -// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests -// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse -// than being stuck behind foo.bar -// -// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is -// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried. -type DeviceListUpdater struct { - process *process.ProcessContext - // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 - // request to the remote server and race. - // TODO: Put in an LRU cache to bound growth - userIDToMutex map[string]*sync.Mutex - mu *sync.Mutex // protects UserIDToMutex - - db DeviceListUpdaterDatabase - api DeviceListUpdaterAPI - producer KeyChangeProducer - fedClient fedsenderapi.KeyserverFederationAPI - workerChans []chan gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName - - // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will - // block on or timeout via a select. - userIDToChan map[string]chan bool - userIDToChanMu *sync.Mutex - rsAPI rsapi.KeyserverRoomserverAPI -} - -// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. -// Useful for testing. -type DeviceListUpdaterDatabase interface { - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error -} - -type DeviceListUpdaterAPI interface { - PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error -} - -// KeyChangeProducer is the interface for producers.KeyChange useful for testing. -type KeyChangeProducer interface { - ProduceKeyChanges(keys []api.DeviceMessage) error -} - -// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. -func NewDeviceListUpdater( - process *process.ProcessContext, db DeviceListUpdaterDatabase, - api DeviceListUpdaterAPI, producer KeyChangeProducer, - fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, -) *DeviceListUpdater { - return &DeviceListUpdater{ - process: process, - userIDToMutex: make(map[string]*sync.Mutex), - mu: &sync.Mutex{}, - db: db, - api: api, - producer: producer, - fedClient: fedClient, - thisServer: thisServer, - workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), - userIDToChan: make(map[string]chan bool), - userIDToChanMu: &sync.Mutex{}, - rsAPI: rsAPI, - } -} - -// Start the device list updater, which will try to refresh any stale device lists. -func (u *DeviceListUpdater) Start() error { - for i := 0; i < len(u.workerChans); i++ { - // Allocate a small buffer per channel. - // If the buffer limit is reached, backpressure will cause the processing of EDUs - // to stop (in this transaction) until key requests can be made. - ch := make(chan gomatrixserverlib.ServerName, 10) - u.workerChans[i] = ch - go u.worker(ch) - } - - staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) - if err != nil { - return err - } - offset, step := time.Second*10, time.Second - if max := len(staleLists); max > 120 { - step = (time.Second * 120) / time.Duration(max) - } - for _, userID := range staleLists { - userID := userID // otherwise we are only sending the last entry - time.AfterFunc(offset, func() { - u.notifyWorkers(userID) - }) - offset += step - } - return nil -} - -// CleanUp removes stale device entries for users we don't share a room with anymore -func (u *DeviceListUpdater) CleanUp() error { - staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) - if err != nil { - return err - } - - res := rsapi.QueryLeftUsersResponse{} - if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { - return err - } - - if len(res.LeftUsers) == 0 { - return nil - } - logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) - return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) -} - -func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { - u.mu.Lock() - defer u.mu.Unlock() - if u.userIDToMutex[userID] == nil { - u.userIDToMutex[userID] = &sync.Mutex{} - } - return u.userIDToMutex[userID] -} - -// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. -// Blocks until the device list is synced or the timeout is reached. -func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { - mu := u.mutex(userID) - mu.Lock() - err := u.db.MarkDeviceListStale(ctx, userID, true) - mu.Unlock() - if err != nil { - return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err) - } - u.notifyWorkers(userID) - return nil -} - -// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest, -// which assumes when /send 200 OKs that the device lists have been updated. -func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { - isDeviceListStale, err := u.update(ctx, event) - if err != nil { - return err - } - if isDeviceListStale { - // poke workers to handle stale device lists - u.notifyWorkers(event.UserID) - } - return nil -} - -func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) { - mu := u.mutex(event.UserID) - mu.Lock() - defer mu.Unlock() - // check if we have the prev IDs - exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID) - if err != nil { - return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) - } - // if this is the first time we're hearing about this user, sync the device list manually. - if len(event.PrevID) == 0 { - exists = false - } - util.GetLogger(ctx).WithFields(logrus.Fields{ - "prev_ids_exist": exists, - "user_id": event.UserID, - "device_id": event.DeviceID, - "stream_id": event.StreamID, - "prev_ids": event.PrevID, - "display_name": event.DeviceDisplayName, - "deleted": event.Deleted, - }).Trace("DeviceListUpdater.Update") - - // if we haven't missed anything update the database and notify users - if exists || event.Deleted { - k := event.Keys - if event.Deleted { - k = nil - } - keys := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: k, - UserID: event.UserID, - }, - StreamID: event.StreamID, - }, - } - - // DeviceKeysJSON will side-effect modify this, so it needs - // to be a copy, not sharing any pointers with the above. - deviceKeysCopy := *keys[0].DeviceKeys - deviceKeysCopy.KeyJSON = nil - existingKeys := []api.DeviceMessage{ - { - Type: keys[0].Type, - DeviceKeys: &deviceKeysCopy, - StreamID: keys[0].StreamID, - }, - } - - // fetch what keys we had already and only emit changes - if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { - // non-fatal, log and continue - util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf( - "failed to query device keys json for calculating diffs", - ) - } - - err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) - if err != nil { - return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) - } - - if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil { - return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) - } - return false, nil - } - - err = u.db.MarkDeviceListStale(ctx, event.UserID, true) - if err != nil { - return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err) - } - - return true, nil -} - -func (u *DeviceListUpdater) notifyWorkers(userID string) { - _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return - } - hash := fnv.New32a() - _, _ = hash.Write([]byte(remoteServer)) - index := int(int64(hash.Sum32()) % int64(len(u.workerChans))) - - ch := u.assignChannel(userID) - u.workerChans[index] <- remoteServer - select { - case <-ch: - case <-time.After(10 * time.Second): - // we don't return an error in this case as it's not a failure condition. - // we mainly block for the benefit of sytest anyway - } -} - -func (u *DeviceListUpdater) assignChannel(userID string) chan bool { - u.userIDToChanMu.Lock() - defer u.userIDToChanMu.Unlock() - if ch, ok := u.userIDToChan[userID]; ok { - return ch - } - ch := make(chan bool) - u.userIDToChan[userID] = ch - return ch -} - -func (u *DeviceListUpdater) clearChannel(userID string) { - u.userIDToChanMu.Lock() - defer u.userIDToChanMu.Unlock() - if ch, ok := u.userIDToChan[userID]; ok { - close(ch) - delete(u.userIDToChan, userID) - } -} - -func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { - retries := make(map[gomatrixserverlib.ServerName]time.Time) - retriesMu := &sync.Mutex{} - // restarter goroutine which will inject failed servers into ch when it is time - go func() { - var serversToRetry []gomatrixserverlib.ServerName - for { - serversToRetry = serversToRetry[:0] // reuse memory - time.Sleep(time.Second) - retriesMu.Lock() - now := time.Now() - for srv, retryAt := range retries { - if now.After(retryAt) { - serversToRetry = append(serversToRetry, srv) - } - } - for _, srv := range serversToRetry { - delete(retries, srv) - } - retriesMu.Unlock() - for _, srv := range serversToRetry { - ch <- srv - } - } - }() - for serverName := range ch { - retriesMu.Lock() - _, exists := retries[serverName] - retriesMu.Unlock() - if exists { - // Don't retry a server that we're already waiting for. - continue - } - waitTime, shouldRetry := u.processServer(serverName) - if shouldRetry { - retriesMu.Lock() - if _, exists = retries[serverName]; !exists { - retries[serverName] = time.Now().Add(waitTime) - } - retriesMu.Unlock() - } - } -} - -func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { - ctx := u.process.Context() - logger := util.GetLogger(ctx).WithField("server_name", serverName) - deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() - - waitTime := defaultWaitTime // How long should we wait to try again? - successCount := 0 // How many user requests failed? - - userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) - if err != nil { - logger.WithError(err).Error("Failed to load stale device lists") - return waitTime, true - } - - defer func() { - for _, userID := range userIDs { - // always clear the channel to unblock Update calls regardless of success/failure - u.clearChannel(userID) - } - }() - - for _, userID := range userIDs { - userWait, err := u.processServerUser(ctx, serverName, userID) - if err != nil { - if userWait > waitTime { - waitTime = userWait - } - break - } - successCount++ - } - - allUsersSucceeded := successCount == len(userIDs) - if !allUsersSucceeded { - logger.WithFields(logrus.Fields{ - "total": len(userIDs), - "succeeded": successCount, - "failed": len(userIDs) - successCount, - "wait_time": waitTime, - }).Debug("Failed to query device keys for some users") - } - return waitTime, !allUsersSucceeded -} - -func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) - defer cancel() - logger := util.GetLogger(ctx).WithFields(logrus.Fields{ - "server_name": serverName, - "user_id": userID, - }) - res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - return time.Minute * 10, err - } - switch e := err.(type) { - case *json.UnmarshalTypeError, *json.SyntaxError: - logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID) - return defaultWaitTime, nil - case *fedsenderapi.FederationClientError: - if e.RetryAfter > 0 { - return e.RetryAfter, err - } else if e.Blacklisted { - return time.Hour * 8, err - } - case net.Error: - // Use the default waitTime, if it's a timeout. - // It probably doesn't make sense to try further users. - if !e.Timeout() { - logger.WithError(e).Debug("GetUserDevices returned net.Error") - return time.Minute * 10, err - } - case gomatrix.HTTPError: - // The remote server returned an error, give it some time to recover. - // This is to avoid spamming remote servers, which may not be Matrix servers anymore. - if e.Code >= 300 { - logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") - return hourWaitTime, err - } - default: - // Something else failed - logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err) - return time.Minute * 10, err - } - } - if res.UserID != userID { - logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID) - return defaultWaitTime, nil - } - if res.MasterKey != nil || res.SelfSigningKey != nil { - uploadReq := &api.PerformUploadDeviceKeysRequest{ - UserID: userID, - } - uploadRes := &api.PerformUploadDeviceKeysResponse{} - if res.MasterKey != nil { - if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil { - uploadReq.MasterKey = *res.MasterKey - } - } - if res.SelfSigningKey != nil { - if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil { - uploadReq.SelfSigningKey = *res.SelfSigningKey - } - } - _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes) - } - err = u.updateDeviceList(&res) - if err != nil { - logger.WithError(err).Error("Fetched device list but failed to store/emit it") - return defaultWaitTime, err - } - return defaultWaitTime, nil -} - -func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error { - ctx := context.Background() // we've got the keys, don't time out when persisting them to the database. - keys := make([]api.DeviceMessage, len(res.Devices)) - existingKeys := make([]api.DeviceMessage, len(res.Devices)) - for i, device := range res.Devices { - keyJSON, err := json.Marshal(device.Keys) - if err != nil { - util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") - continue - } - keys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: res.StreamID, - DeviceKeys: &api.DeviceKeys{ - DeviceID: device.DeviceID, - DisplayName: device.DisplayName, - UserID: res.UserID, - KeyJSON: keyJSON, - }, - } - existingKeys[i] = api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: res.UserID, - DeviceID: device.DeviceID, - }, - } - } - // fetch what keys we had already and only emit changes - if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil { - // non-fatal, log and continue - util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf( - "failed to query device keys json for calculating diffs", - ) - } - - err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) - if err != nil { - return fmt.Errorf("failed to store remote device keys: %w", err) - } - err = u.db.MarkDeviceListStale(ctx, res.UserID, false) - if err != nil { - return fmt.Errorf("failed to mark device list as fresh: %w", err) - } - err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false) - if err != nil { - return fmt.Errorf("failed to emit key changes for fresh device list: %w", err) - } - return nil -} diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go deleted file mode 100644 index 7d357c95..00000000 --- a/keyserver/internal/device_list_update_default.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !vw - -package internal - -import "time" - -const defaultWaitTime = time.Minute -const hourWaitTime = time.Hour diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go deleted file mode 100644 index 1c60d2eb..00000000 --- a/keyserver/internal/device_list_update_sytest.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build vw - -package internal - -import "time" - -// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite -// results in a one-hour wait time from a previous device so the test times out. This is fine for -// production, but makes an otherwise passing test fail. -const defaultWaitTime = time.Second -const hourWaitTime = time.Second diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go deleted file mode 100644 index 60a2c2f3..00000000 --- a/keyserver/internal/device_list_update_test.go +++ /dev/null @@ -1,431 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "crypto/ed25519" - "fmt" - "io" - "net/http" - "net/url" - "reflect" - "strings" - "sync" - "testing" - "time" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ( - ctx = context.Background() -) - -type mockKeyChangeProducer struct { - events []api.DeviceMessage -} - -func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error { - p.events = append(p.events, keys...) - return nil -} - -type mockDeviceListUpdaterDatabase struct { - staleUsers map[string]bool - prevIDsExist func(string, []int64) bool - storedKeys []api.DeviceMessage - mu sync.Mutex // protect staleUsers -} - -func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { - return nil -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - d.mu.Lock() - defer d.mu.Unlock() - var result []string - for userID, isStale := range d.staleUsers { - if !isStale { - continue - } - _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return nil, err - } - if len(domains) == 0 { - result = append(result, userID) - continue - } - for _, d := range domains { - if remoteServer == d { - result = append(result, userID) - break - } - } - } - return result, nil -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - d.mu.Lock() - defer d.mu.Unlock() - d.staleUsers[userID] = isStale - return nil -} - -func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool { - d.mu.Lock() - defer d.mu.Unlock() - return d.staleUsers[userID] -} - -// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key -// for this (user, device). Does not modify the stream ID for keys. -func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error { - d.storedKeys = append(d.storedKeys, keys...) - return nil -} - -// PrevIDsExists returns true if all prev IDs exist for this user. -func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - return d.prevIDsExist(userID, prevIDs), nil -} - -func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return nil -} - -type mockDeviceListUpdaterAPI struct { -} - -func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error { - return nil -} - -type roundTripper struct { - fn func(*http.Request) (*http.Response, error) -} - -func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - return t.fn(req) -} - -func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { - _, pkey, _ := ed25519.GenerateKey(nil) - fedClient := gomatrixserverlib.NewFederationClient( - []*gomatrixserverlib.SigningIdentity{ - { - ServerName: gomatrixserverlib.ServerName("example.test"), - KeyID: gomatrixserverlib.KeyID("ed25519:test"), - PrivateKey: pkey, - }, - }, - ) - fedClient.Client = *gomatrixserverlib.NewClient( - gomatrixserverlib.WithTransport(&roundTripper{tripper}), - ) - return fedClient -} - -// Test that the device keys get persisted and emitted if we have the previous IDs. -func TestUpdateHavePrevID(t *testing.T) { - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return true - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") - event := gomatrixserverlib.DeviceListUpdateEvent{ - DeviceDisplayName: "Foo Bar", - Deleted: false, - DeviceID: "FOO", - Keys: []byte(`{"key":"value"}`), - PrevID: []int64{0}, - StreamID: 1, - UserID: "@alice:localhost", - } - err := updater.Update(ctx, event) - if err != nil { - t.Fatalf("Update returned an error: %s", err) - } - want := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: event.StreamID, - DeviceKeys: &api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: event.Keys, - UserID: event.UserID, - }, - } - if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { - t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) - } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) - } - if db.isStale(event.UserID) { - t.Errorf("%s incorrectly marked as stale", event.UserID) - } -} - -// Test that device keys are fetched from the remote server if we are missing prev IDs -// and that the user's devices are marked as stale until it succeeds. -func TestUpdateNoPrevID(t *testing.T) { - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return false - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - remoteUserID := "@alice:example.somewhere" - var wg sync.WaitGroup - wg.Add(1) - keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` - fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { - defer wg.Done() - if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) { - return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) - } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(` - { - "user_id": "` + remoteUserID + `", - "stream_id": 5, - "devices": [ - { - "device_id": "JLAFKJWSCS", - "keys": ` + keyJSON + `, - "device_display_name": "Mobile Phone" - } - ] - } - `)), - }, nil - }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") - if err := updater.Start(); err != nil { - t.Fatalf("failed to start updater: %s", err) - } - event := gomatrixserverlib.DeviceListUpdateEvent{ - DeviceDisplayName: "Mobile Phone", - Deleted: false, - DeviceID: "another_device_id", - Keys: []byte(`{"key":"value"}`), - PrevID: []int64{3}, - StreamID: 4, - UserID: remoteUserID, - } - err := updater.Update(ctx, event) - - if err != nil { - t.Fatalf("Update returned an error: %s", err) - } - t.Log("waiting for /users/devices to be called...") - wg.Wait() - // wait a bit for db to be updated... - time.Sleep(100 * time.Millisecond) - want := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - StreamID: 5, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "JLAFKJWSCS", - DisplayName: "Mobile Phone", - UserID: remoteUserID, - KeyJSON: []byte(keyJSON), - }, - } - // Now we should have a fresh list and the keys and emitted something - if db.isStale(event.UserID) { - t.Errorf("%s still marked as stale", event.UserID) - } - if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { - t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) - t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) - } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) - } - -} - -// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the -// update is still ongoing. -func TestDebounce(t *testing.T) { - t.Skipf("panic on closed channel on GHA") - db := &mockDeviceListUpdaterDatabase{ - staleUsers: make(map[string]bool), - prevIDsExist: func(string, []int64) bool { - return true - }, - } - ap := &mockDeviceListUpdaterAPI{} - producer := &mockKeyChangeProducer{} - fedCh := make(chan *http.Response, 1) - srv := gomatrixserverlib.ServerName("example.com") - userID := "@alice:example.com" - keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}` - incomingFedReq := make(chan struct{}) - fedClient := newFedClient(func(req *http.Request) (*http.Response, error) { - if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) { - return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path) - } - close(incomingFedReq) - return <-fedCh, nil - }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") - if err := updater.Start(); err != nil { - t.Fatalf("failed to start updater: %s", err) - } - - // hit this 5 times - var wg sync.WaitGroup - wg.Add(5) - for i := 0; i < 5; i++ { - go func() { - defer wg.Done() - if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil { - t.Errorf("ManualUpdate: %s", err) - } - }() - } - - // wait until the updater hits federation - select { - case <-incomingFedReq: - case <-time.After(time.Second): - t.Fatalf("timed out waiting for updater to hit federation") - } - - // user should be marked as stale - if !db.isStale(userID) { - t.Errorf("user %s not marked as stale", userID) - } - // now send the response over federation - fedCh <- &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(` - { - "user_id": "` + userID + `", - "stream_id": 5, - "devices": [ - { - "device_id": "JLAFKJWSCS", - "keys": ` + keyJSON + `, - "device_display_name": "Mobile Phone" - } - ] - } - `)), - } - close(fedCh) - // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response - // and should panic with read on a closed channel - wg.Wait() - - // user is no longer stale now - if db.isStale(userID) { - t.Errorf("user %s is marked as stale", userID) - } -} - -func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { - t.Helper() - - base, _, _ := testrig.Base(nil) - connStr, clearDB := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) - if err != nil { - t.Fatal(err) - } - - return db, clearDB -} - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -func TestDeviceListUpdater_CleanUp(t *testing.T) { - processCtx := process.NewProcessContext() - - alice := test.NewUser(t) - bob := test.NewUser(t) - - // Bob is not joined to any of our rooms - rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clearDB := mustCreateKeyserverDB(t, dbType) - defer clearDB() - - // This should not get deleted - if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { - t.Error(err) - } - - // this one should get deleted - if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { - t.Error(err) - } - - updater := NewDeviceListUpdater(processCtx, db, nil, - nil, nil, - 0, rsAPI, "test") - if err := updater.CleanUp(); err != nil { - t.Error(err) - } - - // check that we still have Alice in our stale list - staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Error(err) - } - - // There should only be Alice - wantCount := 1 - if count := len(staleUsers); count != wantCount { - t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) - } - - if staleUsers[0] != alice.ID { - t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) - } - }) -} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go deleted file mode 100644 index 9a08a0bb..00000000 --- a/keyserver/internal/internal.go +++ /dev/null @@ -1,816 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "sync" - "time" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" -) - -type KeyInternalAPI struct { - DB storage.Database - Cfg *config.KeyServer - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater -} - -func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { - a.UserAPI = i -} - -func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error { - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) - if err != nil { - res.Error = &api.KeyError{ - Err: err.Error(), - } - return nil - } - res.Offset = latest - res.UserIDs = userIDs - return nil -} - -func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error { - res.KeyErrors = make(map[string]map[string]*api.KeyError) - if len(req.DeviceKeys) > 0 { - a.uploadLocalDeviceKeys(ctx, req, res) - } - if len(req.OneTimeKeys) > 0 { - a.uploadOneTimeKeys(ctx, req, res) - } - otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - return err - } - res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks} - return nil -} - -func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error { - res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) - res.Failures = make(map[string]interface{}) - // wrap request map in a top-level by-domain map - domainToDeviceKeys := make(map[string]map[string]map[string]string) - for userID, val := range req.OneTimeKeys { - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - continue // ignore invalid users - } - nested, ok := domainToDeviceKeys[string(serverName)] - if !ok { - nested = make(map[string]map[string]string) - } - nested[userID] = val - domainToDeviceKeys[string(serverName)] = nested - } - for domain, local := range domainToDeviceKeys { - if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - // claim local keys - keys, err := a.DB.ClaimKeys(ctx, local) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), - } - } - util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys") - for _, key := range keys { - _, ok := res.OneTimeKeys[key.UserID] - if !ok { - res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage) - } - _, ok = res.OneTimeKeys[key.UserID][key.DeviceID] - if !ok { - res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) - } - for keyID, keyJSON := range key.KeyJSON { - res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON - } - } - delete(domainToDeviceKeys, domain) - } - if len(domainToDeviceKeys) > 0 { - a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) - } - return nil -} - -func (a *KeyInternalAPI) claimRemoteKeys( - ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, -) { - var wg sync.WaitGroup // Wait for fan-out goroutines to finish - var mu sync.Mutex // Protects the response struct - var claimed int // Number of keys claimed in total - var failures int // Number of servers we failed to ask - - util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys)) - wg.Add(len(domainToDeviceKeys)) - - for d, k := range domainToDeviceKeys { - go func(domain string, keysToClaim map[string]map[string]string) { - fedCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - defer wg.Done() - - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) - - mu.Lock() - defer mu.Unlock() - - if err != nil { - util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed") - res.Failures[domain] = map[string]interface{}{ - "message": err.Error(), - } - failures++ - return - } - - for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys { - res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) - for deviceID, keys := range deviceIDToKeys { - res.OneTimeKeys[userID][deviceID] = keys - claimed += len(keys) - } - } - }(d, k) - } - - wg.Wait() - util.GetLogger(ctx).WithFields(logrus.Fields{ - "num_keys": claimed, - "num_failures": failures, - }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys)) -} - -func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error { - if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("Failed to delete device keys: %s", err), - } - } - return nil -} - -func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error { - count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("Failed to query OTK counts: %s", err), - } - return nil - } - res.Count = *count - return nil -} - -func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error { - msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query DB for device keys: %s", err), - } - return nil - } - maxStreamID := int64(0) - // remove deleted devices - var result []api.DeviceMessage - for _, m := range msgs { - if m.StreamID > maxStreamID { - maxStreamID = m.StreamID - } - if m.KeyJSON == nil || len(m.KeyJSON) == 0 { - continue - } - result = append(result, m) - } - res.Devices = result - res.StreamID = maxStreamID - return nil -} - -// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present -// in our database. -func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error { - knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true) - if err != nil { - return err - } - if len(knownDevices) == 0 { - return nil // fmt.Errorf("unknown user %s", req.UserID) - } - - for i := range knownDevices { - if knownDevices[i].DeviceID == req.DeviceID { - return nil // we already know about this device - } - } - - return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID) -} - -// nolint:gocyclo -func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error { - var respMu sync.Mutex - res.DeviceKeys = make(map[string]map[string]json.RawMessage) - res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) - res.Failures = make(map[string]interface{}) - - // make a map from domain to device keys - domainToDeviceKeys := make(map[string]map[string][]string) - domainToCrossSigningKeys := make(map[string]map[string]struct{}) - for userID, deviceIDs := range req.UserToDevices { - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - continue // ignore invalid users - } - domain := string(serverName) - // query local devices - if a.Cfg.Matrix.IsLocalServerName(serverName) { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query local device keys: %s", err), - } - return nil - } - - // pull out display names after we have the keys so we handle wildcards correctly - var dids []string - for _, dk := range deviceKeys { - dids = append(dids, dk.DeviceID) - } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ - DeviceIDs: dids, - }, &queryRes) - if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") - } - - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - for _, dk := range deviceKeys { - if len(dk.KeyJSON) == 0 { - continue // don't include blank keys - } - // inject display name if known (either locally or remotely) - displayName := dk.DisplayName - if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { - displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName - } - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{displayName}) - res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON - } - } else { - domainToDeviceKeys[domain] = make(map[string][]string) - domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) - } - // work out if our cross-signing request for this user was - // satisfied, if not add them to the list of things to fetch - if _, ok := res.MasterKeys[userID]; !ok { - if _, ok := domainToCrossSigningKeys[domain]; !ok { - domainToCrossSigningKeys[domain] = make(map[string]struct{}) - } - domainToCrossSigningKeys[domain][userID] = struct{}{} - } - if _, ok := res.SelfSigningKeys[userID]; !ok { - if _, ok := domainToCrossSigningKeys[domain]; !ok { - domainToCrossSigningKeys[domain] = make(map[string]struct{}) - } - domainToCrossSigningKeys[domain][userID] = struct{}{} - } - } - - // attempt to satisfy key queries from the local database first as we should get device updates pushed to us - domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys) - if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 { - // perform key queries for remote devices - a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) - } - - // Now that we've done the potentially expensive work of asking the federation, - // try filling the cross-signing keys from the database that we know about. - a.crossSigningKeysFromDatabase(ctx, req, res) - - // Finally, append signatures that we know about - // TODO: This is horrible because we need to round-trip the signature from - // JSON, add the signatures and marshal it again, for some reason? - - for targetUserID, masterKey := range res.MasterKeys { - if masterKey.Signatures == nil { - masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for targetKeyID := range masterKey.Keys { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID) - if err != nil { - // Stop executing the function if the context was canceled/the deadline was exceeded, - // as we can't continue without a valid context. - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil - } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") - continue - } - if len(sigMap) == 0 { - continue - } - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if _, ok := masterKey.Signatures[sourceUserID]; !ok { - masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig - } - } - } - } - - for targetUserID, forUserID := range res.DeviceKeys { - for targetKeyID, key := range forUserID { - sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID)) - if err != nil { - // Stop executing the function if the context was canceled/the deadline was exceeded, - // as we can't continue without a valid context. - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil - } - logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed") - continue - } - if len(sigMap) == 0 { - continue - } - var deviceKey gomatrixserverlib.DeviceKeys - if err = json.Unmarshal(key, &deviceKey); err != nil { - continue - } - if deviceKey.Signatures == nil { - deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sourceUserID, forSourceUser := range sigMap { - for sourceKeyID, sourceSig := range forSourceUser { - if _, ok := deviceKey.Signatures[sourceUserID]; !ok { - deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig - } - } - if js, err := json.Marshal(deviceKey); err == nil { - res.DeviceKeys[targetUserID][targetKeyID] = js - } - } - } - return nil -} - -func (a *KeyInternalAPI) remoteKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string, -) map[string]map[string][]string { - fetchRemote := make(map[string]map[string][]string) - for domain, userToDeviceMap := range domainToDeviceKeys { - for userID, deviceIDs := range userToDeviceMap { - // we can't safely return keys from the db when all devices are requested as we don't - // know if one has just been added. - if len(deviceIDs) > 0 { - err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs) - if err == nil { - continue - } - util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase") - } - // fetch device lists from remote - if _, ok := fetchRemote[domain]; !ok { - fetchRemote[domain] = make(map[string][]string) - } - fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...) - - } - } - return fetchRemote -} - -func (a *KeyInternalAPI) queryRemoteKeys( - ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, - domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{}, -) { - resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) - // allows us to wait until all federation servers have been poked - var wg sync.WaitGroup - // mutex for writing directly to res (e.g failures) - var respMu sync.Mutex - - domains := map[string]struct{}{} - for domain := range domainToDeviceKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - domains[domain] = struct{}{} - } - for domain := range domainToCrossSigningKeys { - if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { - continue - } - domains[domain] = struct{}{} - } - wg.Add(len(domains)) - - // fan out - for domain := range domains { - go a.queryRemoteKeysOnServer( - ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain], - &wg, &respMu, timeout, resultCh, res, - ) - } - - // Close the result channel when the goroutines have quit so the for .. range exits - go func() { - wg.Wait() - close(resultCh) - }() - - processResult := func(result *gomatrixserverlib.RespQueryKeys) { - respMu.Lock() - defer respMu.Unlock() - for userID, nest := range result.DeviceKeys { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - for deviceID, deviceKey := range nest { - keyJSON, err := json.Marshal(deviceKey) - if err != nil { - continue - } - res.DeviceKeys[userID][deviceID] = keyJSON - } - } - - for userID, body := range result.MasterKeys { - res.MasterKeys[userID] = body - } - - for userID, body := range result.SelfSigningKeys { - res.SelfSigningKeys[userID] = body - } - - // TODO: do we want to persist these somewhere now - // that we have fetched them? - } - - for result := range resultCh { - processResult(result) - } -} - -func (a *KeyInternalAPI) queryRemoteKeysOnServer( - ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{}, - wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys, - res *api.QueryKeysResponse, -) { - defer wg.Done() - fedCtx := ctx - if timeout > 0 { - var cancel context.CancelFunc - fedCtx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - } - // for users who we do not have any knowledge about, try to start doing device list updates for them - // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but - // lack a stream ID. - userIDsForAllDevices := map[string]struct{}{} - for userID, deviceIDs := range devKeys { - if len(deviceIDs) == 0 { - userIDsForAllDevices[userID] = struct{}{} - } - } - // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing - // a device list update, so we'll populate those back into the /keys/query list if not - for userID := range crossSigningKeys { - if devKeys == nil { - devKeys = map[string][]string{} - } - if _, ok := userIDsForAllDevices[userID]; !ok { - devKeys[userID] = []string{} - } - } - for userID := range userIDsForAllDevices { - err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID) - if err != nil { - logrus.WithFields(logrus.Fields{ - logrus.ErrorKey: err, - "user_id": userID, - "server": serverName, - }).Error("Failed to manually update device lists for user") - // try to do it via /keys/query - devKeys[userID] = []string{} - continue - } - // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this - // user so the fact that we're populating all devices here isn't a problem so long as we have devices. - err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil) - if err != nil { - logrus.WithFields(logrus.Fields{ - logrus.ErrorKey: err, - "user_id": userID, - "server": serverName, - }).Error("Failed to manually update device lists for user") - // try to do it via /keys/query - devKeys[userID] = []string{} - continue - } - } - if len(devKeys) == 0 { - return - } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) - if err == nil { - resultCh <- &queryKeysResp - return - } - respMu.Lock() - res.Failures[serverName] = map[string]interface{}{ - "message": err.Error(), - } - respMu.Unlock() - - // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server - // is down, better to return something than nothing at all. Clients can know about the failure by - // inspecting the failures map though so they can know it's a cached response. - for userID, dkeys := range devKeys { - // drop the error as it's already a failure at this point - _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys) - } - - // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache - respMu.Lock() - if len(res.DeviceKeys) > 0 { - delete(res.Failures, serverName) - } - respMu.Unlock() -} - -func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( - ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string, -) error { - keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) - // if we can't query the db or there are fewer keys than requested, fetch from remote. - if err != nil { - return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err) - } - if len(keys) < len(deviceIDs) { - return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID) - } - if len(deviceIDs) == 0 && len(keys) == 0 { - return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) - } - respMu.Lock() - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - respMu.Unlock() - - for _, key := range keys { - if len(key.KeyJSON) == 0 { - continue // ignore deleted keys - } - // inject the display name - key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{key.DisplayName}) - respMu.Lock() - res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON - respMu.Unlock() - } - return nil -} - -func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - // get a list of devices from the user API that actually exist, as - // we won't store keys for devices that don't exist - uapidevices := &userapi.QueryDevicesResponse{} - if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil { - res.Error = &api.KeyError{ - Err: err.Error(), - } - return - } - if !uapidevices.UserExists { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("user %q does not exist", req.UserID), - } - return - } - existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices)) - for _, key := range uapidevices.Devices { - existingDeviceMap[key.ID] = struct{}{} - } - - // Get all of the user existing device keys so we can check for changes. - existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()), - } - return - } - - // Work out whether we have device keys in the keyserver for devices that - // no longer exist in the user API. This is mostly an exercise to ensure - // that we keep some integrity between the two. - var toClean []gomatrixserverlib.KeyID - for _, k := range existingKeys { - if _, ok := existingDeviceMap[k.DeviceID]; !ok { - toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID)) - } - } - - if len(toClean) > 0 { - if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil { - logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean)) - } else { - logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean)) - } - } - - var keysToStore []api.DeviceMessage - - if req.OnlyDisplayNameUpdates { - for _, existingKey := range existingKeys { - for _, newKey := range req.DeviceKeys { - switch { - case existingKey.UserID != newKey.UserID: - continue - case existingKey.DeviceID != newKey.DeviceID: - continue - case existingKey.DisplayName != newKey.DisplayName: - existingKey.DisplayName = newKey.DisplayName - } - } - keysToStore = append(keysToStore, existingKey) - } - } else { - // assert that the user ID / device ID are not lying for each key - for _, key := range req.DeviceKeys { - var serverName gomatrixserverlib.ServerName - _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID) - if err != nil { - continue // ignore invalid users - } - if !a.Cfg.Matrix.IsLocalServerName(serverName) { - continue // ignore remote users - } - if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key.WithStreamID(0)) - continue // deleted keys don't need sanity checking - } - // check that the device in question actually exists in the user - // API before we try and store a key for it - if _, ok := existingDeviceMap[key.DeviceID]; !ok { - continue - } - gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str - gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str - if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key.WithStreamID(0)) - continue - } - - res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ - Err: fmt.Sprintf( - "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", - gotUserID, key.UserID, gotDeviceID, key.DeviceID, - ), - }) - } - } - - // store the device keys and emit changes - err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), - } - return - } - err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates) - if err != nil { - util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) - } -} - -func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { - if req.UserID == "" { - res.Error = &api.KeyError{ - Err: "user ID missing", - } - } - if req.DeviceID != "" && len(req.OneTimeKeys) == 0 { - counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err), - } - } - if counts != nil { - res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) - } - return - } - for _, key := range req.OneTimeKeys { - // grab existing keys based on (user/device/algorithm/key ID) - keyIDsWithAlgorithms := make([]string, len(key.KeyJSON)) - i := 0 - for keyIDWithAlgo := range key.KeyJSON { - keyIDsWithAlgorithms[i] = keyIDWithAlgo - i++ - } - existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms) - if err != nil { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: "failed to query existing one-time keys: " + err.Error(), - }) - continue - } - for keyIDWithAlgo := range existingKeys { - // if keys exist and the JSON doesn't match, error out as the key already exists - if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo), - }) - continue - } - } - // store one-time keys - counts, err := a.DB.StoreOneTimeKeys(ctx, key) - if err != nil { - res.KeyError(req.UserID, req.DeviceID, &api.KeyError{ - Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()), - }) - continue - } - // collect counts - res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts) - } - -} - -func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error { - // if we only want to update the display names, we can skip the checks below - if onlyUpdateDisplayName { - return producer.ProduceKeyChanges(new) - } - // find keys in new that are not in existing - var keysAdded []api.DeviceMessage - for _, newKey := range new { - exists := false - for _, existingKey := range existing { - // Do not treat the absence of keys as equal, or else we will not emit key changes - // when users delete devices which never had a key to begin with as both KeyJSONs are nil. - if existingKey.DeviceKeysEqual(&newKey) { - exists = true - break - } - } - if !exists { - keysAdded = append(keysAdded, newKey) - } - } - return producer.ProduceKeyChanges(keysAdded) -} diff --git a/keyserver/internal/internal_test.go b/keyserver/internal/internal_test.go deleted file mode 100644 index 8a2c9c5d..00000000 --- a/keyserver/internal/internal_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package internal_test - -import ( - "context" - "reflect" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" -) - -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - t.Helper() - connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewDatabase(nil, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }) - if err != nil { - t.Fatalf("failed to create new user db: %v", err) - } - return db, close -} - -func Test_QueryDeviceMessages(t *testing.T) { - alice := test.NewUser(t) - type args struct { - req *api.QueryDeviceMessagesRequest - res *api.QueryDeviceMessagesResponse - } - tests := []struct { - name string - args args - wantErr bool - want *api.QueryDeviceMessagesResponse - }{ - { - name: "no existing keys", - args: args{ - req: &api.QueryDeviceMessagesRequest{ - UserID: "@doesNotExist:localhost", - }, - res: &api.QueryDeviceMessagesResponse{}, - }, - want: &api.QueryDeviceMessagesResponse{}, - }, - { - name: "existing user returns devices", - args: args{ - req: &api.QueryDeviceMessagesRequest{ - UserID: alice.ID, - }, - res: &api.QueryDeviceMessagesResponse{}, - }, - want: &api.QueryDeviceMessagesResponse{ - StreamID: 6, - Devices: []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - DisplayName: "first device", - UserID: alice.ID, - KeyJSON: []byte("ghi"), - }, - }, - { - Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{ - DeviceID: "mySecondDevice", - DisplayName: "second device", - UserID: alice.ID, - KeyJSON: []byte("jkl"), - }, // streamID 6 - }, - }, - }, - }, - } - - deviceMessages := []api.DeviceMessage{ - { // not the user we're looking for - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - UserID: "@doesNotExist:localhost", - }, - // streamID 1 for this user - }, - { // empty keyJSON will be ignored - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - }, // streamID 1 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte("abc"), - }, // streamID 2 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte("def"), - }, // streamID 3 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - UserID: alice.ID, - KeyJSON: []byte(""), - }, // streamID 4 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "myDevice", - DisplayName: "first device", - UserID: alice.ID, - KeyJSON: []byte("ghi"), - }, // streamID 5 - }, - { - Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{ - DeviceID: "mySecondDevice", - UserID: alice.ID, - KeyJSON: []byte("jkl"), - DisplayName: "second device", - }, // streamID 6 - }, - } - ctx := context.Background() - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, closeDB := mustCreateDatabase(t, dbType) - defer closeDB() - if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil { - t.Fatalf("failed to store local devicesKeys") - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &internal.KeyInternalAPI{ - DB: db, - } - if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr { - t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr) - } - got := tt.args.res - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want) - } - }) - } - }) -} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go deleted file mode 100644 index 2d143682..00000000 --- a/keyserver/keyserver.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package keyserver - -import ( - "github.com/sirupsen/logrus" - - rsapi "github.com/matrix-org/dendrite/roomserver/api" - - fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/consumers" - "github.com/matrix-org/dendrite/keyserver/internal" - "github.com/matrix-org/dendrite/keyserver/producers" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" -) - -// NewInternalAPI returns a concerete implementation of the internal API. Callers -// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI( - base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, - rsAPI rsapi.KeyserverRoomserverAPI, -) api.KeyInternalAPI { - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - - db, err := storage.NewDatabase(base, &cfg.Database) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to key server database") - } - - keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), - JetStream: js, - DB: db, - } - ap := &internal.KeyInternalAPI{ - DB: db, - Cfg: cfg, - FedClient: fedClient, - Producer: keyChangeProducer, - } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable - ap.Updater = updater - - // Remove users which we don't share a room with anymore - if err := updater.CleanUp(); err != nil { - logrus.WithError(err).Error("failed to cleanup stale device lists") - } - - go func() { - if err := updater.Start(); err != nil { - logrus.WithError(err).Panicf("failed to start device list updater") - } - }() - - dlConsumer := consumers.NewDeviceListUpdateConsumer( - base.ProcessContext, cfg, js, updater, - ) - if err := dlConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start device list consumer") - } - - sigConsumer := consumers.NewSigningKeyUpdateConsumer( - base.ProcessContext, cfg, js, ap, - ) - if err := sigConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start signing key consumer") - } - - return ap -} diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go deleted file mode 100644 index 159b280f..00000000 --- a/keyserver/keyserver_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package keyserver - -import ( - "context" - "testing" - - roomserver "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -type mockKeyserverRoomserverAPI struct { - leftUsers []string -} - -func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { - res.LeftUsers = m.leftUsers - return nil -} - -// Merely tests that we can create an internal keyserver API -func Test_NewInternalAPI(t *testing.T) { - rsAPI := &mockKeyserverRoomserverAPI{} - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, dbType) - defer closeBase() - _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) - }) -} diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go deleted file mode 100644 index f86c3417..00000000 --- a/keyserver/producers/keychange.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" -) - -// KeyChange produces key change events for the sync API and federation sender to consume -type KeyChange struct { - Topic string - JetStream nats.JetStreamContext - DB storage.Database -} - -// ProduceKeyChanges creates new change events for each key -func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { - userToDeviceCount := make(map[string]int) - for _, key := range keys { - id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) - if err != nil { - return err - } - key.DeviceChangeID = id - value, err := json.Marshal(key) - if err != nil { - return err - } - - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, key.UserID) - m.Data = value - - _, err = p.JetStream.PublishMsg(m) - if err != nil { - return err - } - - userToDeviceCount[key.UserID]++ - } - for userID, count := range userToDeviceCount { - logrus.WithFields(logrus.Fields{ - "user_id": userID, - "num_key_changes": count, - }).Tracef("Produced to key change topic '%s'", p.Topic) - } - return nil -} - -func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error { - output := &api.DeviceMessage{ - Type: api.TypeCrossSigningUpdate, - OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{ - CrossSigningKeyUpdate: key, - }, - } - - id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) - if err != nil { - return err - } - output.DeviceChangeID = id - - value, err := json.Marshal(output) - if err != nil { - return err - } - - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, key.UserID) - m.Data = value - - _, err = p.JetStream.PublishMsg(m) - if err != nil { - return err - } - - logrus.WithFields(logrus.Fields{ - "user_id": key.UserID, - }).Tracef("Produced to cross-signing update topic '%s'", p.Topic) - return nil -} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go deleted file mode 100644 index c6a8f44c..00000000 --- a/keyserver/storage/interface.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database interface { - // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination - // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. - ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - - // StoreOneTimeKeys persists the given one-time keys. - StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - - // OneTimeKeysCount returns a count of all OTKs for this device. - OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - - // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. - DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - - // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). - // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. - // Returns an error if there was a problem storing the keys. - StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error - - // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key - // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior - // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. - StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error - - // PrevIDsExists returns true if all prev IDs exist for this user. - PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) - - // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. - // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. - DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - - // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying - // cross-signing signatures relating to that device. - DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error - - // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key - // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. - ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - - // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. - // `userID` is the the user who has changed their keys in some way. - StoreKeyChange(ctx context.Context, userID string) (int64, error) - - // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). - // A to offset of types.OffsetNewest means no upper limit. - // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. - // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - - // MarkDeviceListStale sets the stale bit for this user to isStale. - MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error - - CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) - CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) - CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) - - StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error - StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - - DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, - ) error -} diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go deleted file mode 100644 index 1022157e..00000000 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningKeysSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( - user_id TEXT NOT NULL, - key_type SMALLINT NOT NULL, - key_data TEXT NOT NULL, - PRIMARY KEY (user_id, key_type) -); -` - -const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1" - -const upsertCrossSigningKeysForUserSQL = "" + - "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" + - " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" - -type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt -} - -func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { - s := &crossSigningKeysStatements{ - db: db, - } - _, err := db.Exec(crossSigningKeysSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - }.Prepare(db) -} - -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, -) (r types.CrossSigningKeyMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - return -} - -func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, -) error { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return fmt.Errorf("unknown key purpose %q", keyType) - } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { - return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go deleted file mode 100644 index 4536b7d8..00000000 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningSigsSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) -); - -CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); -` - -const selectCrossSigningSigsForTargetSQL = "" + - "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" - -const upsertCrossSigningSigsForTargetSQL = "" + - "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + - " VALUES($1, $2, $3, $4, $5)" + - " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" - -const deleteCrossSigningSigsForTargetSQL = "" + - "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" - -type crossSigningSigsStatements struct { - db *sql.DB - selectCrossSigningSigsForTargetStmt *sql.Stmt - upsertCrossSigningSigsForTargetStmt *sql.Stmt - deleteCrossSigningSigsForTargetStmt *sql.Stmt -} - -func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { - s := &crossSigningSigsStatements{ - db: db, - } - _, err := db.Exec(crossSigningSigsSchema) - if err != nil { - return nil, err - } - - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: cross signing signature indexes", - Up: deltas.UpFixCrossSigningSignatureIndexes, - }) - if err = m.Up(context.Background()); err != nil { - return nil, err - } - - return s, sqlutil.StatementList{ - {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, - {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, - {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, - }.Prepare(db) -} - -func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") - r = types.CrossSigningSigMap{} - for rows.Next() { - var userID string - var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { - return nil, err - } - if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - r[userID][keyID] = signature - } - return -} - -func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} - -func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) error { - if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { - return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go deleted file mode 100644 index 0cfe9e79..00000000 --- a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - // start counting from the last max offset, else 0. We need to do a count(*) first to see if there - // even are entries in this table to know if we can query for log_offset. Without the count then - // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't - // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ - var count int - _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count) - if count > 0 { - var maxOffset int64 - _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) - if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { - return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) - } - } - - _, err := tx.ExecContext(ctx, ` - -- make the new table - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers - DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - log_offset BIGINT NOT NULL, - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go deleted file mode 100644 index 1a3d4fee..00000000 --- a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; - ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); - - CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; - ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); - - DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go deleted file mode 100644 index 2aa11c52..00000000 --- a/keyserver/storage/postgres/device_keys_table.go +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "time" - - "github.com/lib/pq" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var deviceKeysSchema = ` --- Stores device keys for users -CREATE TABLE IF NOT EXISTS keyserver_device_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. - -- This means we do not store an unbounded append-only log of device keys, which is not actually - -- required in the spec because in the event of a missed update the server fetches the entire - -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. - stream_id BIGINT NOT NULL, - display_name TEXT, - -- Clobber based on tuple of user/device. - CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) -); -` - -const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - -const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" - -const selectBatchDeviceKeysWithEmptiesSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" - -const selectMaxStreamForUserSQL = "" + - "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - -const countStreamIDsForUserSQL = "" + - "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" - -const deleteDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const deleteAllDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1" - -type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - countStreamIDsForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt -} - -func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { - s := &deviceKeysStatements{ - db: db, - } - _, err := db.Exec(deviceKeysSchema) - if err != nil { - return nil, err - } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string - var streamID int64 - var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - if err != nil && err != sql.ErrNoRows { - return err - } - // this will be '' when there is no device - keys[i].Type = api.TypeDeviceKeyUpdate - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID - if displayName.Valid { - keys[i].DisplayName = displayName.String - } - } - return nil -} - -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { - // nullable if there are no results - var nullStream sql.NullInt64 - err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - if err == sql.ErrNoRows { - err = nil - } - if nullStream.Valid { - streamID = nullStream.Int64 - } - return -} - -func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { - // nullable if there are no results - var count sql.NullInt32 - err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) - if err != nil { - return 0, err - } - if count.Valid { - return int(count.Int32), nil - } - return 0, nil -} - -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) - if err != nil { - return err - } - } - return nil -} - -func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} - -func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) - return err -} - -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - var stmt *sql.Stmt - if includeEmpty { - stmt = s.selectBatchDeviceKeysWithEmptiesStmt - } else { - stmt = s.selectBatchDeviceKeysStmt - } - rows, err := stmt.QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - deviceIDMap := make(map[string]bool) - for _, d := range deviceIDs { - deviceIDMap[d] = true - } - var result []api.DeviceMessage - var displayName sql.NullString - for rows.Next() { - dk := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: userID, - }, - } - if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { - return nil, err - } - if displayName.Valid { - dk.DisplayName = displayName.String - } - // include the key if we want all keys (no device) or it was asked - if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { - result = append(result, dk) - } - } - return result, rows.Err() -} diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go deleted file mode 100644 index c0e3429c..00000000 --- a/keyserver/storage/postgres/key_changes_table.go +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var keyChangesSchema = ` --- Stores key change information about users. Used to determine when to send updated device lists to clients. -CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; -CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), - user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) -); -` - -// Replace based on user ID. We don't care how many times the user's keys have changed, only that they -// have changed, hence we can just keep bumping the change ID for this user. -const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (user_id)" + - " VALUES ($1)" + - " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + - " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + - " RETURNING change_id" - -const selectKeyChangesSQL = "" + - "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" - -type keyChangesStatements struct { - db *sql.DB - upsertKeyChangeStmt *sql.Stmt - selectKeyChangesStmt *sql.Stmt -} - -func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { - s := &keyChangesStatements{ - db: db, - } - _, err := db.Exec(keyChangesSchema) - if err != nil { - return s, err - } - - if err = executeMigration(context.Background(), db); err != nil { - return nil, err - } - return s, nil -} - -func executeMigration(ctx context.Context, db *sql.DB) error { - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column partition was removed from the table - migrationName := "keyserver: refactor key changes" - - var cName string - err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed - if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { - return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) - } - return nil - } - return err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: migrationName, - Up: deltas.UpRefactorKeyChanges, - }) - - return m.Up(ctx) -} - -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { - err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) - return -} - -func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, fromOffset, toOffset int64, -) (userIDs []string, latestOffset int64, err error) { - latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) - if err != nil { - return nil, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") - for rows.Next() { - var userID string - var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { - return nil, 0, err - } - if offset > latestOffset { - latestOffset = offset - } - userIDs = append(userIDs, userID) - } - return -} diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go deleted file mode 100644 index 2117efca..00000000 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var oneTimeKeysSchema = ` --- Stores one-time public keys for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - key_id TEXT NOT NULL, - algorithm TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- Clobber based on 4-uple of user/device/key/algorithm. - CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) -); - -CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); -` - -const upsertKeysSQL = "" + - "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + - " DO UPDATE SET key_json = $6" - -const selectKeysSQL = "" + - "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" - -const selectKeysCountSQL = "" + - "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + - " x GROUP BY algorithm" - -const deleteOneTimeKeySQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" - -const selectKeyByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" - -const deleteOneTimeKeysSQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" - -type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt - selectKeyByAlgorithmStmt *sql.Stmt - deleteOneTimeKeyStmt *sql.Stmt - deleteOneTimeKeysStmt *sql.Stmt -} - -func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { - s := &oneTimeKeysStatements{ - db: db, - } - _, err := db.Exec(oneTimeKeysSchema) - if err != nil { - return nil, err - } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - - result := make(map[string]json.RawMessage) - var ( - algorithmWithID string - keyJSONStr string - ) - for rows.Next() { - if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { - return nil, err - } - result[algorithmWithID] = json.RawMessage(keyJSONStr) - } - return result, rows.Err() -} - -func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - counts := &api.OneTimeKeysCount{ - DeviceID: deviceID, - UserID: userID, - KeyCount: make(map[string]int), - } - rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - return counts, nil -} - -func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { - now := time.Now().Unix() - counts := &api.OneTimeKeysCount{ - DeviceID: keys.DeviceID, - UserID: keys.UserID, - KeyCount: make(map[string]int), - } - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return nil, err - } - } - rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - - return counts, rows.Err() -} - -func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( - ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, -) (map[string]json.RawMessage, error) { - var keyID string - var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) - return map[string]json.RawMessage{ - algorithm + ":" + keyID: json.RawMessage(keyJSON), - }, err -} - -func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go deleted file mode 100644 index 248ddfb4..00000000 --- a/keyserver/storage/postgres/stale_device_lists.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "time" - - "github.com/lib/pq" - - "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" -) - -var staleDeviceListsSchema = ` --- Stores whether a user's device lists are stale or not. -CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( - user_id TEXT PRIMARY KEY NOT NULL, - domain TEXT NOT NULL, - is_stale BOOLEAN NOT NULL, - ts_added_secs BIGINT NOT NULL -); - -CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); -` - -const upsertStaleDeviceListSQL = "" + - "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id)" + - " DO UPDATE SET is_stale = $3, ts_added_secs = $4" - -const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" - -const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" - -const deleteStaleDevicesSQL = "" + - "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" - -type staleDeviceListsStatements struct { - upsertStaleDeviceListStmt *sql.Stmt - selectStaleDeviceListsWithDomainsStmt *sql.Stmt - selectStaleDeviceListsStmt *sql.Stmt - deleteStaleDeviceListsStmt *sql.Stmt -} - -func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{} - _, err := db.Exec(staleDeviceListsSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, - {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, - {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, - {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, - }.Prepare(db) -} - -func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return err - } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) - return err -} - -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - // we only query for 1 domain or all domains so optimise for those use cases - if len(domains) == 0 { - rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) - if err != nil { - return nil, err - } - return rowsToUserIDs(ctx, rows) - } - var result []string - for _, domain := range domains { - rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) - if err != nil { - return nil, err - } - userIDs, err := rowsToUserIDs(ctx, rows) - if err != nil { - return nil, err - } - result = append(result, userIDs...) - } - return result, nil -} - -// DeleteStaleDeviceLists removes users from stale device lists -func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( - ctx context.Context, txn *sql.Tx, userIDs []string, -) error { - stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) - _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) - return err -} - -func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { - defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go deleted file mode 100644 index 35e63055..00000000 --- a/keyserver/storage/postgres/storage.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase creates a new sync server database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - var err error - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) - if err != nil { - return nil, err - } - otk, err := NewPostgresOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewPostgresDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewPostgresKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewPostgresStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewPostgresCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewPostgresCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go deleted file mode 100644 index 54dd6ddc..00000000 --- a/keyserver/storage/shared/storage.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges - StaleDeviceListsTable tables.StaleDeviceLists - CrossSigningKeysTable tables.CrossSigningKeys - CrossSigningSigsTable tables.CrossSigningSigs -} - -func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) -} - -func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) - return err - }) - return -} - -func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) -} - -func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) -} - -func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { - count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) - if err != nil { - return false, err - } - return count == len(prevIDs), nil -} - -func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, userID := range clearUserIDs { - err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) - if err != nil { - return err - } - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { - // work out the latest stream IDs for each user - userIDToStreamID := make(map[string]int64) - for _, k := range keys { - userIDToStreamID[k.UserID] = 0 - } - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID := range userIDToStreamID { - streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) - if err != nil { - return err - } - userIDToStreamID[userID] = streamID - } - // set the stream IDs for each key - for i := range keys { - k := keys[i] - userIDToStreamID[k.UserID]++ // start stream from 1 - k.StreamID = userIDToStreamID[k.UserID] - keys[i] = k - } - return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) - }) -} - -func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) -} - -func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { - var result []api.OneTimeKeys - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for userID, deviceToAlgo := range userToDeviceToAlgorithm { - for deviceID, algo := range deviceToAlgo { - keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) - if err != nil { - return err - } - if keyJSON != nil { - result = append(result, api.OneTimeKeys{ - UserID: userID, - DeviceID: deviceID, - KeyJSON: keyJSON, - }) - } - } - } - return nil - }) - return result, err -} - -func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { - err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) - return err - }) - return -} - -func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) -} - -// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. -// If no domains are given, all user IDs with stale device lists are returned. -func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) -} - -// MarkDeviceListStale sets the stale bit for this user to isStale. -func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) - }) -} - -// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying -// cross-signing signatures relating to that device. -func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for _, deviceID := range deviceIDs { - if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) - } - if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) - } - if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) - } - } - return nil - }) -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { - keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) - if err != nil { - return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) - } - results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} - for purpose, key := range keyMap { - keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) - result := gomatrixserverlib.CrossSigningKey{ - UserID: userID, - Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, - Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - keyID: key, - }, - } - sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) - if err != nil { - continue - } - for sigUserID, forSigUserID := range sigMap { - if userID != sigUserID { - continue - } - if result.Signatures == nil { - result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - if _, ok := result.Signatures[sigUserID]; !ok { - result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - for sigKeyID, sigBytes := range forSigUserID { - result.Signatures[sigUserID][sigKeyID] = sigBytes - } - } - results[purpose] = result - } - return results, nil -} - -// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. -func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { - return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) -} - -// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. -func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) { - return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID) -} - -// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. -func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - for keyType, keyData := range keyMap { - if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil { - return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err) - } - } - return nil - }) -} - -// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice. -func (d *Database) StoreCrossSigningSigsForTarget( - ctx context.Context, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err) - } - return nil - }) -} - -// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. -func (d *Database) DeleteStaleDeviceLists( - ctx context.Context, - userIDs []string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) - }) -} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go deleted file mode 100644 index e103d988..00000000 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningKeysSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( - user_id TEXT NOT NULL, - key_type INTEGER NOT NULL, - key_data TEXT NOT NULL, - PRIMARY KEY (user_id, key_type) -); -` - -const selectCrossSigningKeysForUserSQL = "" + - "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + - " WHERE user_id = $1" - -const upsertCrossSigningKeysForUserSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + - " VALUES($1, $2, $3)" - -type crossSigningKeysStatements struct { - db *sql.DB - selectCrossSigningKeysForUserStmt *sql.Stmt - upsertCrossSigningKeysForUserStmt *sql.Stmt -} - -func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { - s := &crossSigningKeysStatements{ - db: db, - } - _, err := db.Exec(crossSigningKeysSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, - {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, - }.Prepare(db) -} - -func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, -) (r types.CrossSigningKeyMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") - r = types.CrossSigningKeyMap{} - for rows.Next() { - var keyTypeInt int16 - var keyData gomatrixserverlib.Base64Bytes - if err := rows.Scan(&keyTypeInt, &keyData); err != nil { - return nil, err - } - keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] - if !ok { - return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) - } - r[keyType] = keyData - } - return -} - -func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( - ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, -) error { - keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] - if !ok { - return fmt.Errorf("unknown key purpose %q", keyType) - } - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { - return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go deleted file mode 100644 index 7a153e8f..00000000 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -var crossSigningSigsSchema = ` -CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) -); - -CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); -` - -const selectCrossSigningSigsForTargetSQL = "" + - "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" - -const upsertCrossSigningSigsForTargetSQL = "" + - "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + - " VALUES($1, $2, $3, $4, $5)" - -const deleteCrossSigningSigsForTargetSQL = "" + - "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" - -type crossSigningSigsStatements struct { - db *sql.DB - selectCrossSigningSigsForTargetStmt *sql.Stmt - upsertCrossSigningSigsForTargetStmt *sql.Stmt - deleteCrossSigningSigsForTargetStmt *sql.Stmt -} - -func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { - s := &crossSigningSigsStatements{ - db: db, - } - _, err := db.Exec(crossSigningSigsSchema) - if err != nil { - return nil, err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "keyserver: cross signing signature indexes", - Up: deltas.UpFixCrossSigningSignatureIndexes, - }) - if err = m.Up(context.Background()); err != nil { - return nil, err - } - - return s, sqlutil.StatementList{ - {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, - {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, - {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, - }.Prepare(db) -} - -func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed") - r = types.CrossSigningSigMap{} - for rows.Next() { - var userID string - var keyID gomatrixserverlib.KeyID - var signature gomatrixserverlib.Base64Bytes - if err := rows.Scan(&userID, &keyID, &signature); err != nil { - return nil, err - } - if _, ok := r[userID]; !ok { - r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} - } - r[userID][keyID] = signature - } - return -} - -func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - originUserID string, originKeyID gomatrixserverlib.KeyID, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, - signature gomatrixserverlib.Base64Bytes, -) error { - if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { - return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} - -func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( - ctx context.Context, txn *sql.Tx, - targetUserID string, targetKeyID gomatrixserverlib.KeyID, -) error { - if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { - return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go deleted file mode 100644 index cd0f19df..00000000 --- a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - // start counting from the last max offset, else 0. - var maxOffset int64 - var userID string - _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) - - _, err := tx.ExecContext(ctx, ` - -- make the new table - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (user_id) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - // to start counting from maxOffset, insert a row with that value - if userID != "" { - _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) - return err - } - return nil -} - -func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers - DROP TABLE IF EXISTS keyserver_key_changes; - CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - offset BIGINT NOT NULL, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (partition, offset) - ); - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go deleted file mode 100644 index d4e38dea..00000000 --- a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package deltas - -import ( - "context" - "database/sql" - "fmt" -) - -func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) - ); - - INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) - SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; - - DROP TABLE keyserver_cross_signing_sigs; - ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; - - CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); - `) - if err != nil { - return fmt.Errorf("failed to execute upgrade: %w", err) - } - return nil -} - -func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { - _, err := tx.ExecContext(ctx, ` - CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( - origin_user_id TEXT NOT NULL, - origin_key_id TEXT NOT NULL, - target_user_id TEXT NOT NULL, - target_key_id TEXT NOT NULL, - signature TEXT NOT NULL, - PRIMARY KEY (origin_user_id, target_user_id, target_key_id) - ); - - INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature) - SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs; - - DROP TABLE keyserver_cross_signing_sigs; - ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs; - - DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx; - `) - if err != nil { - return fmt.Errorf("failed to execute downgrade: %w", err) - } - return nil -} diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go deleted file mode 100644 index 73768da5..00000000 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "strings" - "time" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var deviceKeysSchema = ` --- Stores device keys for users -CREATE TABLE IF NOT EXISTS keyserver_device_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - stream_id BIGINT NOT NULL, - display_name TEXT, - -- Clobber based on tuple of user/device. - UNIQUE (user_id, device_id) -); -` - -const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" - -const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" - -const selectBatchDeviceKeysWithEmptiesSQL = "" + - "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" - -const selectMaxStreamForUserSQL = "" + - "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" - -const countStreamIDsForUserSQL = "" + - "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" - -const deleteDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" - -const deleteAllDeviceKeysSQL = "" + - "DELETE FROM keyserver_device_keys WHERE user_id=$1" - -type deviceKeysStatements struct { - db *sql.DB - upsertDeviceKeysStmt *sql.Stmt - selectDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysStmt *sql.Stmt - selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt - selectMaxStreamForUserStmt *sql.Stmt - deleteDeviceKeysStmt *sql.Stmt - deleteAllDeviceKeysStmt *sql.Stmt -} - -func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { - s := &deviceKeysStatements{ - db: db, - } - _, err := db.Exec(deviceKeysSchema) - if err != nil { - return nil, err - } - if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { - return nil, err - } - if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil { - return nil, err - } - if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil { - return nil, err - } - if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil { - return nil, err - } - if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} - -func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) - return err -} - -func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { - deviceIDMap := make(map[string]bool) - for _, d := range deviceIDs { - deviceIDMap[d] = true - } - var stmt *sql.Stmt - if includeEmpty { - stmt = s.selectBatchDeviceKeysWithEmptiesStmt - } else { - stmt = s.selectBatchDeviceKeysStmt - } - rows, err := stmt.QueryContext(ctx, userID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") - var result []api.DeviceMessage - var displayName sql.NullString - for rows.Next() { - dk := api.DeviceMessage{ - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - UserID: userID, - }, - } - if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { - return nil, err - } - if displayName.Valid { - dk.DisplayName = displayName.String - } - // include the key if we want all keys (no device) or it was asked - if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { - result = append(result, dk) - } - } - return result, rows.Err() -} - -func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string - var streamID int64 - var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) - if err != nil && err != sql.ErrNoRows { - return err - } - // this will be '' when there is no device - keys[i].Type = api.TypeDeviceKeyUpdate - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID - if displayName.Valid { - keys[i].DisplayName = displayName.String - } - } - return nil -} - -func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { - // nullable if there are no results - var nullStream sql.NullInt64 - err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) - if err == sql.ErrNoRows { - err = nil - } - if nullStream.Valid { - streamID = nullStream.Int64 - } - return -} - -func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { - iStreamIDs := make([]interface{}, len(streamIDs)+1) - iStreamIDs[0] = userID - for i := range streamIDs { - iStreamIDs[i+1] = streamIDs[i] - } - query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) - // nullable if there are no results - var count sql.NullInt64 - err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) - if err != nil { - return 0, err - } - if count.Valid { - return int(count.Int64), nil - } - return 0, nil -} - -func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { - for _, key := range keys { - now := time.Now().Unix() - _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, - ) - if err != nil { - return err - } - } - return nil -} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go deleted file mode 100644 index 0c844d67..00000000 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var keyChangesSchema = ` --- Stores key change information about users. Used to determine when to send updated device lists to clients. -CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - change_id INTEGER PRIMARY KEY AUTOINCREMENT, - -- The key owner - user_id TEXT NOT NULL, - UNIQUE (user_id) -); -` - -// Replace based on user ID. We don't care how many times the user's keys have changed, only that they -// have changed, hence we can just keep bumping the change ID for this user. -const upsertKeyChangeSQL = "" + - "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + - " VALUES ($1)" + - " RETURNING change_id" - -const selectKeyChangesSQL = "" + - "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" - -type keyChangesStatements struct { - db *sql.DB - upsertKeyChangeStmt *sql.Stmt - selectKeyChangesStmt *sql.Stmt -} - -func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { - s := &keyChangesStatements{ - db: db, - } - _, err := db.Exec(keyChangesSchema) - if err != nil { - return s, err - } - - if err = executeMigration(context.Background(), db); err != nil { - return nil, err - } - - return s, nil -} - -func executeMigration(ctx context.Context, db *sql.DB) error { - // TODO: Remove when we are sure we are not having goose artefacts in the db - // This forces an error, which indicates the migration is already applied, since the - // column partition was removed from the table - migrationName := "keyserver: refactor key changes" - - var cName string - err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed - if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { - return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) - } - return nil - } - return err - } - m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: migrationName, - Up: deltas.UpRefactorKeyChanges, - }) - return m.Up(ctx) -} - -func (s *keyChangesStatements) Prepare() (err error) { - if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { - return err - } - if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { - return err - } - return nil -} - -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { - err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) - return -} - -func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, fromOffset, toOffset int64, -) (userIDs []string, latestOffset int64, err error) { - latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) - if err != nil { - return nil, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") - for rows.Next() { - var userID string - var offset int64 - if err := rows.Scan(&userID, &offset); err != nil { - return nil, 0, err - } - if offset > latestOffset { - latestOffset = offset - } - userIDs = append(userIDs, userID) - } - return -} diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go deleted file mode 100644 index 7a923d0e..00000000 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "encoding/json" - "time" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage/tables" -) - -var oneTimeKeysSchema = ` --- Stores one-time public keys for users -CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( - user_id TEXT NOT NULL, - device_id TEXT NOT NULL, - key_id TEXT NOT NULL, - algorithm TEXT NOT NULL, - ts_added_secs BIGINT NOT NULL, - key_json TEXT NOT NULL, - -- Clobber based on 4-uple of user/device/key/algorithm. - UNIQUE (user_id, device_id, key_id, algorithm) -); - -CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); -` - -const upsertKeysSQL = "" + - "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + - " VALUES ($1, $2, $3, $4, $5, $6)" + - " ON CONFLICT (user_id, device_id, key_id, algorithm)" + - " DO UPDATE SET key_json = $6" - -const selectKeysSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" - -const selectKeysCountSQL = "" + - "SELECT algorithm, COUNT(key_id) FROM " + - " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + - " x GROUP BY algorithm" - -const deleteOneTimeKeySQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" - -const selectKeyByAlgorithmSQL = "" + - "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" - -const deleteOneTimeKeysSQL = "" + - "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" - -type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt - selectKeyByAlgorithmStmt *sql.Stmt - deleteOneTimeKeyStmt *sql.Stmt - deleteOneTimeKeysStmt *sql.Stmt -} - -func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { - s := &oneTimeKeysStatements{ - db: db, - } - _, err := db.Exec(oneTimeKeysSchema) - if err != nil { - return nil, err - } - if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { - return nil, err - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return nil, err - } - if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { - return nil, err - } - if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { - return nil, err - } - if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil { - return nil, err - } - return s, nil -} - -func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - - wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) - for _, ka := range keyIDsWithAlgorithms { - wantSet[ka] = true - } - - result := make(map[string]json.RawMessage) - for rows.Next() { - var keyID string - var algorithm string - var keyJSONStr string - if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { - return nil, err - } - keyIDWithAlgo := algorithm + ":" + keyID - if wantSet[keyIDWithAlgo] { - result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) - } - } - return result, rows.Err() -} - -func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { - counts := &api.OneTimeKeysCount{ - DeviceID: deviceID, - UserID: userID, - KeyCount: make(map[string]int), - } - rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - return counts, nil -} - -func (s *oneTimeKeysStatements) InsertOneTimeKeys( - ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, -) (*api.OneTimeKeysCount, error) { - now := time.Now().Unix() - counts := &api.OneTimeKeysCount{ - DeviceID: keys.DeviceID, - UserID: keys.UserID, - KeyCount: make(map[string]int), - } - for keyIDWithAlgo, keyJSON := range keys.KeyJSON { - algo, keyID := keys.Split(keyIDWithAlgo) - _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( - ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), - ) - if err != nil { - return nil, err - } - } - rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") - for rows.Next() { - var algorithm string - var count int - if err = rows.Scan(&algorithm, &count); err != nil { - return nil, err - } - counts.KeyCount[algorithm] = count - } - - return counts, rows.Err() -} - -func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( - ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, -) (map[string]json.RawMessage, error) { - var keyID string - var keyJSON string - err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) - if err != nil { - return nil, err - } - if keyJSON == "" { - return nil, nil - } - return map[string]json.RawMessage{ - algorithm + ":" + keyID: json.RawMessage(keyJSON), - }, err -} - -func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { - _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) - return err -} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go deleted file mode 100644 index fd76a6e3..00000000 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "strings" - "time" - - "github.com/matrix-org/dendrite/internal/sqlutil" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" -) - -var staleDeviceListsSchema = ` --- Stores whether a user's device lists are stale or not. -CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( - user_id TEXT PRIMARY KEY NOT NULL, - domain TEXT NOT NULL, - is_stale BOOLEAN NOT NULL, - ts_added_secs BIGINT NOT NULL -); - -CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); -` - -const upsertStaleDeviceListSQL = "" + - "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT (user_id)" + - " DO UPDATE SET is_stale = $3, ts_added_secs = $4" - -const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" - -const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" - -const deleteStaleDevicesSQL = "" + - "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" - -type staleDeviceListsStatements struct { - db *sql.DB - upsertStaleDeviceListStmt *sql.Stmt - selectStaleDeviceListsWithDomainsStmt *sql.Stmt - selectStaleDeviceListsStmt *sql.Stmt - // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime -} - -func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{ - db: db, - } - _, err := db.Exec(staleDeviceListsSchema) - if err != nil { - return nil, err - } - return s, sqlutil.StatementList{ - {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, - {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, - {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, - // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime - }.Prepare(db) -} - -func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return err - } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) - return err -} - -func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - // we only query for 1 domain or all domains so optimise for those use cases - if len(domains) == 0 { - rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) - if err != nil { - return nil, err - } - return rowsToUserIDs(ctx, rows) - } - var result []string - for _, domain := range domains { - rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) - if err != nil { - return nil, err - } - userIDs, err := rowsToUserIDs(ctx, rows) - if err != nil { - return nil, err - } - result = append(result, userIDs...) - } - return result, nil -} - -// DeleteStaleDeviceLists removes users from stale device lists -func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( - ctx context.Context, txn *sql.Tx, userIDs []string, -) error { - qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) - stmt, err := s.db.Prepare(qry) - if err != nil { - return err - } - defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") - stmt = sqlutil.TxStmt(txn, stmt) - - params := make([]any, len(userIDs)) - for i := range userIDs { - params[i] = userIDs[i] - } - - _, err = stmt.ExecContext(ctx, params...) - return err -} - -func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { - defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") - for rows.Next() { - var userID string - if err := rows.Scan(&userID); err != nil { - return nil, err - } - result = append(result, userID) - } - return result, rows.Err() -} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go deleted file mode 100644 index 873fe3e2..00000000 --- a/keyserver/storage/sqlite3/storage.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/shared" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) { - db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) - if err != nil { - return nil, err - } - otk, err := NewSqliteOneTimeKeysTable(db) - if err != nil { - return nil, err - } - dk, err := NewSqliteDeviceKeysTable(db) - if err != nil { - return nil, err - } - kc, err := NewSqliteKeyChangesTable(db) - if err != nil { - return nil, err - } - sdl, err := NewSqliteStaleDeviceListsTable(db) - if err != nil { - return nil, err - } - csk, err := NewSqliteCrossSigningKeysTable(db) - if err != nil { - return nil, err - } - css, err := NewSqliteCrossSigningSigsTable(db) - if err != nil { - return nil, err - } - - if err = kc.Prepare(); err != nil { - return nil, err - } - d := &shared.Database{ - DB: db, - Writer: writer, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, - StaleDeviceListsTable: sdl, - CrossSigningKeysTable: csk, - CrossSigningSigsTable: css, - } - return d, nil -} diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go deleted file mode 100644 index ab6a3540..00000000 --- a/keyserver/storage/storage.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !wasm -// +build !wasm - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(base, dbProperties) - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go deleted file mode 100644 index e7a2af7c..00000000 --- a/keyserver/storage/storage_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package storage_test - -import ( - "context" - "reflect" - "sync" - "testing" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/storage" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/test/testrig" -) - -var ctx = context.Background() - -func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database) - if err != nil { - t.Fatalf("failed to create new database: %v", err) - } - return db, close -} - -func MustNotError(t *testing.T, err error) { - t.Helper() - if err == nil { - return - } - t.Fatalf("operation failed: %s", err) -} - -func TestKeyChanges(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - _, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDC { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) - } - if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesNoDupes(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - if deviceChangeIDA == deviceChangeIDB { - t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) - } - deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeID { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) - } - if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -func TestKeyChangesUpperLimit(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") - MustNotError(t, err) - deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") - MustNotError(t, err) - _, err = db.StoreKeyChange(ctx, "@charlie:localhost") - MustNotError(t, err) - userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) - if err != nil { - t.Fatalf("Failed to KeyChanges: %s", err) - } - if latest != deviceChangeIDB { - t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) - } - if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { - t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) - } - }) -} - -var dbLock sync.Mutex -var deviceArray = []string{"AAA", "another_device"} - -// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, -// and that they are returned correctly when querying for device keys. -func TestDeviceKeysStreamIDGeneration(t *testing.T) { - var err error - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, clean := MustCreateDatabase(t, dbType) - defer clean() - alice := "@alice:TestDeviceKeysStreamIDGeneration" - bob := "@bob:TestDeviceKeysStreamIDGeneration" - msgs := []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 1 as this is a different user - }, - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), - }, - // StreamID: 2 as this is a 2nd device key - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) - } - if msgs[1].StreamID != 1 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) - } - if msgs[2].StreamID != 2 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) - } - - // updating a device sets the next stream ID for that user - msgs = []api.DeviceMessage{ - { - Type: api.TypeDeviceKeyUpdate, - DeviceKeys: &api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), - }, - // StreamID: 3 - }, - } - MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) - if msgs[0].StreamID != 3 { - t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) - } - - dbLock.Lock() - defer dbLock.Unlock() - // Querying for device keys returns the latest stream IDs - msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) - - if err != nil { - t.Fatalf("DeviceKeysForUser returned error: %s", err) - } - wantStreamIDs := map[string]int64{ - "AAA": 3, - "another_device": 2, - } - if len(msgs) != len(wantStreamIDs) { - t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) - } - for _, m := range msgs { - if m.StreamID != wantStreamIDs[m.DeviceID] { - t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) - } - } - }) -} diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go deleted file mode 100644 index 75c9053e..00000000 --- a/keyserver/storage/storage_wasm.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "fmt" - - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" -) - -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) { - switch { - case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties) - case dbProperties.ConnectionString.IsPostgres(): - return nil, fmt.Errorf("can't use Postgres implementation") - default: - return nil, fmt.Errorf("unexpected database type") - } -} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go deleted file mode 100644 index 24da1125..00000000 --- a/keyserver/storage/tables/interface.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "context" - "database/sql" - "encoding/json" - - "github.com/matrix-org/dendrite/keyserver/api" - "github.com/matrix-org/dendrite/keyserver/types" - "github.com/matrix-org/gomatrixserverlib" -) - -type OneTimeKeys interface { - SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) - CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) - InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) - // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. - // Returns an empty map if the key does not exist. - SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) - DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error -} - -type DeviceKeys interface { - SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error - InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error - SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) - CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) - SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) - DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error - DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error -} - -type KeyChanges interface { - InsertKeyChange(ctx context.Context, userID string) (int64, error) - // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. - // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. - SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) - - Prepare() error -} - -type StaleDeviceLists interface { - InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error - SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) - DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error -} - -type CrossSigningKeys interface { - SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) - UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error -} - -type CrossSigningSigs interface { - SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) - UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error - DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error -} diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go deleted file mode 100644 index 76d3badd..00000000 --- a/keyserver/storage/tables/stale_device_lists_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package tables_test - -import ( - "context" - "testing" - - "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" - "github.com/matrix-org/dendrite/setup/config" - - "github.com/matrix-org/dendrite/keyserver/storage/postgres" - "github.com/matrix-org/dendrite/keyserver/storage/tables" - "github.com/matrix-org/dendrite/test" -) - -func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { - connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := sqlutil.Open(&config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }, nil) - if err != nil { - t.Fatalf("failed to open database: %s", err) - } - switch dbType { - case test.DBTypePostgres: - tab, err = postgres.NewPostgresStaleDeviceListsTable(db) - case test.DBTypeSQLite: - tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) - } - if err != nil { - t.Fatalf("failed to create new table: %s", err) - } - return tab, close -} - -func TestStaleDeviceLists(t *testing.T) { - alice := test.NewUser(t) - bob := test.NewUser(t) - charlie := "@charlie:localhost" - ctx := context.Background() - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - tab, closeDB := mustCreateTable(t, dbType) - defer closeDB() - - if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { - t.Fatalf("failed to insert stale device: %s", err) - } - - // Query one server - wantStaleUsers := []string{alice.ID, bob.ID} - gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { - t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) - } - - // Query all servers - wantStaleUsers = []string{alice.ID, bob.ID, charlie} - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { - t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) - } - - // Delete stale devices - deleteUsers := []string{alice.ID, bob.ID} - if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { - t.Fatalf("failed to delete stale device lists: %s", err) - } - - // Verify we don't get anything back after deleting - gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) - if err != nil { - t.Fatalf("failed to query stale device lists: %s", err) - } - - if gotCount := len(gotStaleUsers); gotCount > 0 { - t.Fatalf("expected no stale users, got %d", gotCount) - } - }) -} diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go deleted file mode 100644 index 7fb90454..00000000 --- a/keyserver/types/storage.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types - -import ( - "math" - - "github.com/matrix-org/gomatrixserverlib" -) - -const ( - // OffsetNewest tells e.g. the database to get the most current data - OffsetNewest int64 = math.MaxInt64 - // OffsetOldest tells e.g. the database to get the oldest data - OffsetOldest int64 = 0 -) - -// KeyTypePurposeToInt maps a purpose to an integer, which is used in the -// database to reduce the amount of space taken up by this column. -var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{ - gomatrixserverlib.CrossSigningKeyPurposeMaster: 1, - gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2, - gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3, -} - -// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the -// database to reduce the amount of space taken up by this column. -var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{ - 1: gomatrixserverlib.CrossSigningKeyPurposeMaster, - 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning, - 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning, -} - -// Map of purpose -> public key -type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes - -// Map of user ID -> key ID -> signature -type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes |