aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/README.md19
-rw-r--r--keyserver/api/api.go346
-rw-r--r--keyserver/consumers/devicelistupdate.go95
-rw-r--r--keyserver/consumers/signingkeyupdate.go112
-rw-r--r--keyserver/internal/cross_signing.go587
-rw-r--r--keyserver/internal/device_list_update.go579
-rw-r--r--keyserver/internal/device_list_update_default.go22
-rw-r--r--keyserver/internal/device_list_update_sytest.go25
-rw-r--r--keyserver/internal/device_list_update_test.go431
-rw-r--r--keyserver/internal/internal.go816
-rw-r--r--keyserver/internal/internal_test.go156
-rw-r--r--keyserver/keyserver.go86
-rw-r--r--keyserver/keyserver_test.go29
-rw-r--r--keyserver/producers/keychange.go107
-rw-r--r--keyserver/storage/interface.go93
-rw-r--r--keyserver/storage/postgres/cross_signing_keys_table.go102
-rw-r--r--keyserver/storage/postgres/cross_signing_sigs_table.go131
-rw-r--r--keyserver/storage/postgres/deltas/2022012016470000_key_changes.go69
-rw-r--r--keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go47
-rw-r--r--keyserver/storage/postgres/device_keys_table.go228
-rw-r--r--keyserver/storage/postgres/key_changes_table.go134
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go205
-rw-r--r--keyserver/storage/postgres/stale_device_lists.go131
-rw-r--r--keyserver/storage/postgres/storage.go69
-rw-r--r--keyserver/storage/shared/storage.go261
-rw-r--r--keyserver/storage/sqlite3/cross_signing_keys_table.go101
-rw-r--r--keyserver/storage/sqlite3/cross_signing_sigs_table.go129
-rw-r--r--keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go66
-rw-r--r--keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go71
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go225
-rw-r--r--keyserver/storage/sqlite3/key_changes_table.go132
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go219
-rw-r--r--keyserver/storage/sqlite3/stale_device_lists.go145
-rw-r--r--keyserver/storage/sqlite3/storage.go68
-rw-r--r--keyserver/storage/storage.go40
-rw-r--r--keyserver/storage/storage_test.go197
-rw-r--r--keyserver/storage/storage_wasm.go34
-rw-r--r--keyserver/storage/tables/interface.go71
-rw-r--r--keyserver/storage/tables/stale_device_lists_test.go94
-rw-r--r--keyserver/types/storage.go50
40 files changed, 0 insertions, 6522 deletions
diff --git a/keyserver/README.md b/keyserver/README.md
deleted file mode 100644
index fd9f37d2..00000000
--- a/keyserver/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-## Key Server
-
-This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID.
-
-Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API.
-
-### Internal APIs
-- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s).
-- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls.
-- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers.
-- A topic which emits identity keys every time there is a change (addition or deletion).
-
-### Endpoint mappings
-- Client API maps `/keys/upload` to `PerformUploadKeys`.
-- Client API maps `/keys/query` to `QueryKeys`.
-- Client API maps `/keys/claim` to `PerformClaimKeys`.
-- Federation API maps `/user/keys/query` to `QueryKeys`.
-- Federation API maps `/user/keys/claim` to `PerformClaimKeys`.
-- Sync API maps `/keys/changes` to consuming from the Kafka topic.
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
deleted file mode 100644
index 14fced3e..00000000
--- a/keyserver/api/api.go
+++ /dev/null
@@ -1,346 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package api
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "strings"
- "time"
-
- "github.com/matrix-org/gomatrixserverlib"
-
- "github.com/matrix-org/dendrite/keyserver/types"
- userapi "github.com/matrix-org/dendrite/userapi/api"
-)
-
-type KeyInternalAPI interface {
- SyncKeyAPI
- ClientKeyAPI
- FederationKeyAPI
- UserKeyAPI
-
- // SetUserAPI assigns a user API to query when extracting device names.
- SetUserAPI(i userapi.KeyserverUserAPI)
-}
-
-// API functions required by the clientapi
-type ClientKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
- PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
- // PerformClaimKeys claims one-time keys for use in pre-key messages
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
- PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
-}
-
-// API functions required by the userapi
-type UserKeyAPI interface {
- PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
- PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
-}
-
-// API functions required by the syncapi
-type SyncKeyAPI interface {
- QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
- QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
- PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
-}
-
-type FederationKeyAPI interface {
- QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
- QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
- QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
- PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
- PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
-}
-
-// KeyError is returned if there was a problem performing/querying the server
-type KeyError struct {
- Err string `json:"error"`
- IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
- IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
- IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
-}
-
-func (k *KeyError) Error() string {
- return k.Err
-}
-
-type DeviceMessageType int
-
-const (
- TypeDeviceKeyUpdate DeviceMessageType = iota
- TypeCrossSigningUpdate
-)
-
-// DeviceMessage represents the message produced into Kafka by the key server.
-type DeviceMessage struct {
- Type DeviceMessageType `json:"Type,omitempty"`
- *DeviceKeys `json:"DeviceKeys,omitempty"`
- *OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
- // A monotonically increasing number which represents device changes for this user.
- StreamID int64
- DeviceChangeID int64
-}
-
-// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
-type OutputCrossSigningKeyUpdate struct {
- CrossSigningKeyUpdate `json:"signing_keys"`
-}
-
-type CrossSigningKeyUpdate struct {
- MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
- SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
- UserID string `json:"user_id"`
-}
-
-// DeviceKeysEqual returns true if the device keys updates contain the
-// same display name and key JSON. This will return false if either of
-// the updates is not a device keys update, or if the user ID/device ID
-// differ between the two.
-func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
- if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
- return false
- }
- if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
- return false
- }
- if m1.DisplayName != m2.DisplayName {
- return false // different display names
- }
- if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
- return false // either is empty
- }
- return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
-}
-
-// DeviceKeys represents a set of device keys for a single device
-// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
-type DeviceKeys struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // The device display name
- DisplayName string
- // The raw device key JSON
- KeyJSON []byte
-}
-
-// WithStreamID returns a copy of this device message with the given stream ID
-func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
- return DeviceMessage{
- DeviceKeys: k,
- StreamID: streamID,
- }
-}
-
-// OneTimeKeys represents a set of one-time keys for a single device
-// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
-type OneTimeKeys struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // A map of algorithm:key_id => key JSON
- KeyJSON map[string]json.RawMessage
-}
-
-// Split a key in KeyJSON into algorithm and key ID
-func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
- segments := strings.Split(keyIDWithAlgo, ":")
- return segments[0], segments[1]
-}
-
-// OneTimeKeysCount represents the counts of one-time keys for a single device
-type OneTimeKeysCount struct {
- // The user who owns this device
- UserID string
- // The device ID of this device
- DeviceID string
- // algorithm to count e.g:
- // {
- // "curve25519": 10,
- // "signed_curve25519": 20
- // }
- KeyCount map[string]int
-}
-
-// PerformUploadKeysRequest is the request to PerformUploadKeys
-type PerformUploadKeysRequest struct {
- UserID string // Required - User performing the request
- DeviceID string // Optional - Device performing the request, for fetching OTK count
- DeviceKeys []DeviceKeys
- OneTimeKeys []OneTimeKeys
- // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
- // the display name for their respective device, and NOT to modify the keys. The key
- // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
- // Without this flag, requests to modify device display names would delete device keys.
- OnlyDisplayNameUpdates bool
-}
-
-// PerformUploadKeysResponse is the response to PerformUploadKeys
-type PerformUploadKeysResponse struct {
- // A fatal error when processing e.g database failures
- Error *KeyError
- // A map of user_id -> device_id -> Error for tracking failures.
- KeyErrors map[string]map[string]*KeyError
- OneTimeKeyCounts []OneTimeKeysCount
-}
-
-// PerformDeleteKeysRequest asks the keyserver to forget about certain
-// keys, and signatures related to those keys.
-type PerformDeleteKeysRequest struct {
- UserID string
- KeyIDs []gomatrixserverlib.KeyID
-}
-
-// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
-type PerformDeleteKeysResponse struct {
- Error *KeyError
-}
-
-// KeyError sets a key error field on KeyErrors
-func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
- if r.KeyErrors[userID] == nil {
- r.KeyErrors[userID] = make(map[string]*KeyError)
- }
- r.KeyErrors[userID][deviceID] = err
-}
-
-type PerformClaimKeysRequest struct {
- // Map of user_id to device_id to algorithm name
- OneTimeKeys map[string]map[string]string
- Timeout time.Duration
-}
-
-type PerformClaimKeysResponse struct {
- // Map of user_id to device_id to algorithm:key_id to key JSON
- OneTimeKeys map[string]map[string]map[string]json.RawMessage
- // Map of remote server domain to error JSON
- Failures map[string]interface{}
- // Set if there was a fatal error processing this action
- Error *KeyError
-}
-
-type PerformUploadDeviceKeysRequest struct {
- gomatrixserverlib.CrossSigningKeys
- // The user that uploaded the key, should be populated by the clientapi.
- UserID string
-}
-
-type PerformUploadDeviceKeysResponse struct {
- Error *KeyError
-}
-
-type PerformUploadDeviceSignaturesRequest struct {
- Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
- // The user that uploaded the sig, should be populated by the clientapi.
- UserID string
-}
-
-type PerformUploadDeviceSignaturesResponse struct {
- Error *KeyError
-}
-
-type QueryKeysRequest struct {
- // The user ID asking for the keys, e.g. if from a client API request.
- // Will not be populated if the key request came from federation.
- UserID string
- // Maps user IDs to a list of devices
- UserToDevices map[string][]string
- Timeout time.Duration
-}
-
-type QueryKeysResponse struct {
- // Map of remote server domain to error JSON
- Failures map[string]interface{}
- // Map of user_id to device_id to device_key
- DeviceKeys map[string]map[string]json.RawMessage
- // Maps of user_id to cross signing key
- MasterKeys map[string]gomatrixserverlib.CrossSigningKey
- SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // Set if there was a fatal error processing this query
- Error *KeyError
-}
-
-type QueryKeyChangesRequest struct {
- // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
- Offset int64
- // The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
- // Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
- ToOffset int64
-}
-
-type QueryKeyChangesResponse struct {
- // The set of users who have had their keys change.
- UserIDs []string
- // The latest offset represented in this response.
- Offset int64
- // Set if there was a problem handling the request.
- Error *KeyError
-}
-
-type QueryOneTimeKeysRequest struct {
- // The local user to query OTK counts for
- UserID string
- // The device to query OTK counts for
- DeviceID string
-}
-
-type QueryOneTimeKeysResponse struct {
- // OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
- Count OneTimeKeysCount
- Error *KeyError
-}
-
-type QueryDeviceMessagesRequest struct {
- UserID string
-}
-
-type QueryDeviceMessagesResponse struct {
- // The latest stream ID
- StreamID int64
- Devices []DeviceMessage
- Error *KeyError
-}
-
-type QuerySignaturesRequest struct {
- // A map of target user ID -> target key/device IDs to retrieve signatures for
- TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
-}
-
-type QuerySignaturesResponse struct {
- // A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
- Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
- // A map of target user ID -> cross-signing master key
- MasterKeys map[string]gomatrixserverlib.CrossSigningKey
- // A map of target user ID -> cross-signing self-signing key
- SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // A map of target user ID -> cross-signing user-signing key
- UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
- // The request error, if any
- Error *KeyError
-}
-
-type PerformMarkAsStaleRequest struct {
- UserID string
- Domain gomatrixserverlib.ServerName
- DeviceID string
-}
diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go
deleted file mode 100644
index cd911f8c..00000000
--- a/keyserver/consumers/devicelistupdate.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package consumers
-
-import (
- "context"
- "encoding/json"
-
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
- "github.com/sirupsen/logrus"
-
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/setup/jetstream"
- "github.com/matrix-org/dendrite/setup/process"
-)
-
-// DeviceListUpdateConsumer consumes device list updates that came in over federation.
-type DeviceListUpdateConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- updater *internal.DeviceListUpdater
- isLocalServerName func(gomatrixserverlib.ServerName) bool
-}
-
-// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
-func NewDeviceListUpdateConsumer(
- process *process.ProcessContext,
- cfg *config.KeyServer,
- js nats.JetStreamContext,
- updater *internal.DeviceListUpdater,
-) *DeviceListUpdateConsumer {
- return &DeviceListUpdateConsumer{
- ctx: process.Context(),
- jetstream: js,
- durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
- topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
- updater: updater,
- isLocalServerName: cfg.Matrix.IsLocalServerName,
- }
-}
-
-// Start consuming from key servers
-func (t *DeviceListUpdateConsumer) Start() error {
- return jetstream.JetStreamConsumer(
- t.ctx, t.jetstream, t.topic, t.durable, 1,
- t.onMessage, nats.DeliverAll(), nats.ManualAck(),
- )
-}
-
-// onMessage is called in response to a message received on the
-// key change events topic from the key server.
-func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
- msg := msgs[0] // Guaranteed to exist if onMessage is called
- var m gomatrixserverlib.DeviceListUpdateEvent
- if err := json.Unmarshal(msg.Data, &m); err != nil {
- logrus.WithError(err).Errorf("Failed to read from device list update input topic")
- return true
- }
- origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
- if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
- return true
- } else if t.isLocalServerName(serverName) {
- return true
- } else if serverName != origin {
- return true
- }
-
- err := t.updater.Update(ctx, m)
- if err != nil {
- logrus.WithFields(logrus.Fields{
- "user_id": m.UserID,
- "device_id": m.DeviceID,
- "stream_id": m.StreamID,
- "prev_id": m.PrevID,
- }).WithError(err).Errorf("Failed to update device list")
- return false
- }
- return true
-}
diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go
deleted file mode 100644
index bcceaad1..00000000
--- a/keyserver/consumers/signingkeyupdate.go
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package consumers
-
-import (
- "context"
- "encoding/json"
-
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
- "github.com/sirupsen/logrus"
-
- keyapi "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/setup/jetstream"
- "github.com/matrix-org/dendrite/setup/process"
-)
-
-// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
-type SigningKeyUpdateConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- keyAPI *internal.KeyInternalAPI
- cfg *config.KeyServer
- isLocalServerName func(gomatrixserverlib.ServerName) bool
-}
-
-// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
-func NewSigningKeyUpdateConsumer(
- process *process.ProcessContext,
- cfg *config.KeyServer,
- js nats.JetStreamContext,
- keyAPI *internal.KeyInternalAPI,
-) *SigningKeyUpdateConsumer {
- return &SigningKeyUpdateConsumer{
- ctx: process.Context(),
- jetstream: js,
- durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
- topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
- keyAPI: keyAPI,
- cfg: cfg,
- isLocalServerName: cfg.Matrix.IsLocalServerName,
- }
-}
-
-// Start consuming from key servers
-func (t *SigningKeyUpdateConsumer) Start() error {
- return jetstream.JetStreamConsumer(
- t.ctx, t.jetstream, t.topic, t.durable, 1,
- t.onMessage, nats.DeliverAll(), nats.ManualAck(),
- )
-}
-
-// onMessage is called in response to a message received on the
-// signing key update events topic from the key server.
-func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
- msg := msgs[0] // Guaranteed to exist if onMessage is called
- var updatePayload keyapi.CrossSigningKeyUpdate
- if err := json.Unmarshal(msg.Data, &updatePayload); err != nil {
- logrus.WithError(err).Errorf("Failed to read from signing key update input topic")
- return true
- }
- origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
- if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
- logrus.WithError(err).Error("failed to split user id")
- return true
- } else if t.isLocalServerName(serverName) {
- logrus.Warn("dropping device key update from ourself")
- return true
- } else if serverName != origin {
- logrus.Warnf("dropping device key update, %s != %s", serverName, origin)
- return true
- }
-
- keys := gomatrixserverlib.CrossSigningKeys{}
- if updatePayload.MasterKey != nil {
- keys.MasterKey = *updatePayload.MasterKey
- }
- if updatePayload.SelfSigningKey != nil {
- keys.SelfSigningKey = *updatePayload.SelfSigningKey
- }
- uploadReq := &keyapi.PerformUploadDeviceKeysRequest{
- CrossSigningKeys: keys,
- UserID: updatePayload.UserID,
- }
- uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
- if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
- logrus.WithError(err).Error("failed to upload device keys")
- return false
- }
- if uploadRes.Error != nil {
- logrus.WithError(uploadRes.Error).Error("failed to upload device keys")
- return true
- }
-
- return true
-}
diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go
deleted file mode 100644
index 99859dff..00000000
--- a/keyserver/internal/cross_signing.go
+++ /dev/null
@@ -1,587 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package internal
-
-import (
- "bytes"
- "context"
- "crypto/ed25519"
- "database/sql"
- "fmt"
- "strings"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/sirupsen/logrus"
- "golang.org/x/crypto/curve25519"
-)
-
-func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error {
- // Is there exactly one key?
- if len(key.Keys) != 1 {
- return fmt.Errorf("should contain exactly one key")
- }
-
- // Does the key ID match the key value? Iterates exactly once
- for keyID, keyData := range key.Keys {
- b64 := keyData.Encode()
- tokens := strings.Split(string(keyID), ":")
- if len(tokens) != 2 {
- return fmt.Errorf("key ID is incorrectly formatted")
- }
- if tokens[1] != b64 {
- return fmt.Errorf("key ID isn't correct")
- }
- switch tokens[0] {
- case "ed25519":
- if len(keyData) != ed25519.PublicKeySize {
- return fmt.Errorf("ed25519 key is not the correct length")
- }
- case "curve25519":
- if len(keyData) != curve25519.PointSize {
- return fmt.Errorf("curve25519 key is not the correct length")
- }
- default:
- // We can't enforce the key length to be correct for an
- // algorithm that we don't recognise, so instead we'll
- // just make sure that it isn't incredibly excessive.
- if l := len(keyData); l > 4096 {
- return fmt.Errorf("unknown key type is too long (%d bytes)", l)
- }
- }
- }
-
- // Check to see if the signatures make sense
- for _, forOriginUser := range key.Signatures {
- for originKeyID, originSignature := range forOriginUser {
- switch strings.SplitN(string(originKeyID), ":", 1)[0] {
- case "ed25519":
- if len(originSignature) != ed25519.SignatureSize {
- return fmt.Errorf("ed25519 signature is not the correct length")
- }
- case "curve25519":
- return fmt.Errorf("curve25519 signatures are impossible")
- default:
- if l := len(originSignature); l > 4096 {
- return fmt.Errorf("unknown signature type is too long (%d bytes)", l)
- }
- }
- }
- }
-
- // Does the key claim to be from the right user?
- if userID != key.UserID {
- return fmt.Errorf("key has a user ID mismatch")
- }
-
- // Does the key contain the correct purpose?
- useful := false
- for _, usage := range key.Usage {
- if usage == purpose {
- useful = true
- break
- }
- }
- if !useful {
- return fmt.Errorf("key does not contain correct usage purpose")
- }
-
- return nil
-}
-
-// nolint:gocyclo
-func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
- // Find the keys to store.
- byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
- toStore := types.CrossSigningKeyMap{}
- hasMasterKey := false
-
- if len(req.MasterKey.Keys) > 0 {
- if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
- res.Error = &api.KeyError{
- Err: "Master key sanity check failed: " + err.Error(),
- IsInvalidParam: true,
- }
- return nil
- }
-
- byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
- for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
- toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key
- }
- hasMasterKey = true
- }
-
- if len(req.SelfSigningKey.Keys) > 0 {
- if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil {
- res.Error = &api.KeyError{
- Err: "Self-signing key sanity check failed: " + err.Error(),
- IsInvalidParam: true,
- }
- return nil
- }
-
- byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
- for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
- toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key
- }
- }
-
- if len(req.UserSigningKey.Keys) > 0 {
- if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil {
- res.Error = &api.KeyError{
- Err: "User-signing key sanity check failed: " + err.Error(),
- IsInvalidParam: true,
- }
- return nil
- }
-
- byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
- for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
- toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key
- }
- }
-
- // If there's nothing to do then stop here.
- if len(toStore) == 0 {
- res.Error = &api.KeyError{
- Err: "No keys were supplied in the request",
- IsMissingParam: true,
- }
- return nil
- }
-
- // We can't have a self-signing or user-signing key without a master
- // key, so make sure we have one of those. We will also only actually do
- // something if any of the specified keys in the request are different
- // to what we've got in the database, to avoid generating key change
- // notifications unnecessarily.
- existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
- if err != nil {
- res.Error = &api.KeyError{
- Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
- }
- return nil
- }
-
- // If we still can't find a master key for the user then stop the upload.
- // This satisfies the "Fails to upload self-signing key without master key" test.
- if !hasMasterKey {
- if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
- res.Error = &api.KeyError{
- Err: "No master key was found",
- IsMissingParam: true,
- }
- return nil
- }
- }
-
- // Check if anything actually changed compared to what we have in the database.
- changed := false
- for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
- gomatrixserverlib.CrossSigningKeyPurposeMaster,
- gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
- gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
- } {
- old, gotOld := existingKeys[purpose]
- new, gotNew := toStore[purpose]
- if gotOld != gotNew {
- // A new key purpose has been specified that we didn't know before,
- // or one has been removed.
- changed = true
- break
- }
- if !bytes.Equal(old, new) {
- // One of the existing keys for a purpose we already knew about has
- // changed.
- changed = true
- break
- }
- }
- if !changed {
- return nil
- }
-
- // Store the keys.
- if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
- }
- return nil
- }
-
- // Now upload any signatures that were included with the keys.
- for _, key := range byPurpose {
- var targetKeyID gomatrixserverlib.KeyID
- for targetKey := range key.Keys { // iterates once, see sanityCheckKey
- targetKeyID = targetKey
- }
- for sigUserID, forSigUserID := range key.Signatures {
- if sigUserID != req.UserID {
- continue
- }
- for sigKeyID, sigBytes := range forSigUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
- }
- return nil
- }
- }
- }
- }
-
- // Finally, generate a notification that we updated the keys.
- update := api.CrossSigningKeyUpdate{
- UserID: req.UserID,
- }
- if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok {
- update.MasterKey = &mk
- }
- if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok {
- update.SelfSigningKey = &ssk
- }
- if update.MasterKey == nil && update.SelfSigningKey == nil {
- return nil
- }
- if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
- }
- return nil
- }
- return nil
-}
-
-func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
- // Before we do anything, we need the master and self-signing keys for this user.
- // Then we can verify the signatures make sense.
- queryReq := &api.QueryKeysRequest{
- UserID: req.UserID,
- UserToDevices: map[string][]string{},
- }
- queryRes := &api.QueryKeysResponse{}
- for userID := range req.Signatures {
- queryReq.UserToDevices[userID] = []string{}
- }
- _ = a.QueryKeys(ctx, queryReq, queryRes)
-
- selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
- otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
-
- // Sort signatures into two groups: one where people have signed their own
- // keys and one where people have signed someone elses
- for userID, forUserID := range req.Signatures {
- for keyID, keyOrDevice := range forUserID {
- switch key := keyOrDevice.CrossSigningBody.(type) {
- case *gomatrixserverlib.CrossSigningKey:
- if key.UserID == req.UserID {
- if _, ok := selfSignatures[userID]; !ok {
- selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
- }
- selfSignatures[userID][keyID] = keyOrDevice
- } else {
- if _, ok := otherSignatures[userID]; !ok {
- otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
- }
- otherSignatures[userID][keyID] = keyOrDevice
- }
-
- case *gomatrixserverlib.DeviceKeys:
- if key.UserID == req.UserID {
- if _, ok := selfSignatures[userID]; !ok {
- selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
- }
- selfSignatures[userID][keyID] = keyOrDevice
- } else {
- if _, ok := otherSignatures[userID]; !ok {
- otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
- }
- otherSignatures[userID][keyID] = keyOrDevice
- }
-
- default:
- continue
- }
- }
- }
-
- if err := a.processSelfSignatures(ctx, selfSignatures); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
- }
- return nil
- }
-
- if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
- }
- return nil
- }
-
- // Finally, generate a notification that we updated the signatures.
- for userID := range req.Signatures {
- masterKey := queryRes.MasterKeys[userID]
- selfSigningKey := queryRes.SelfSigningKeys[userID]
- update := api.CrossSigningKeyUpdate{
- UserID: userID,
- MasterKey: &masterKey,
- SelfSigningKey: &selfSigningKey,
- }
- if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
- }
- return nil
- }
- }
- return nil
-}
-
-func (a *KeyInternalAPI) processSelfSignatures(
- ctx context.Context,
- signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
-) error {
- // Here we will process:
- // * The user signing their own devices using their self-signing key
- // * The user signing their master key using one of their devices
-
- for targetUserID, forTargetUserID := range signatures {
- for targetKeyID, signature := range forTargetUserID {
- switch sig := signature.CrossSigningBody.(type) {
- case *gomatrixserverlib.CrossSigningKey:
- for keyID := range sig.Keys {
- split := strings.SplitN(string(keyID), ":", 2)
- if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
- targetKeyID = keyID // contains the ed25519: or other scheme
- break
- }
- }
- for originUserID, forOriginUserID := range sig.Signatures {
- for originKeyID, originSig := range forOriginUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(
- ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
- ); err != nil {
- return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
- }
- }
- }
-
- case *gomatrixserverlib.DeviceKeys:
- for originUserID, forOriginUserID := range sig.Signatures {
- for originKeyID, originSig := range forOriginUserID {
- if err := a.DB.StoreCrossSigningSigsForTarget(
- ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
- ); err != nil {
- return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
- }
- }
- }
-
- default:
- return fmt.Errorf("unexpected type assertion")
- }
- }
- }
-
- return nil
-}
-
-func (a *KeyInternalAPI) processOtherSignatures(
- ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
- signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
-) error {
- // Here we will process:
- // * A user signing someone else's master keys using their user-signing keys
-
- for targetUserID, forTargetUserID := range signatures {
- for _, signature := range forTargetUserID {
- switch sig := signature.CrossSigningBody.(type) {
- case *gomatrixserverlib.CrossSigningKey:
- // Find the local copy of the master key. We'll use this to be
- // sure that the supplied stanza matches the key that we think it
- // should be.
- masterKey, ok := queryRes.MasterKeys[targetUserID]
- if !ok {
- return fmt.Errorf("failed to find master key for user %q", targetUserID)
- }
-
- // For each key ID, write the signatures. Maybe there'll be more
- // than one algorithm in the future so it's best not to focus on
- // everything being ed25519:.
- for targetKeyID, suppliedKeyData := range sig.Keys {
- // The master key will be supplied in the request, but we should
- // make sure that it matches what we think the master key should
- // actually be.
- localKeyData, lok := masterKey.Keys[targetKeyID]
- if !lok {
- return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
- } else if !bytes.Equal(suppliedKeyData, localKeyData) {
- return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
- }
-
- // We only care about the signatures from the uploading user, so
- // we will ignore anything that didn't originate from them.
- userSigs, ok := sig.Signatures[userID]
- if !ok {
- return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID)
- }
-
- for originKeyID, originSig := range userSigs {
- if err := a.DB.StoreCrossSigningSigsForTarget(
- ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
- ); err != nil {
- return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
- }
- }
- }
-
- default:
- // Users should only be signing another person's master key,
- // so if we're here, it's probably because it's actually a
- // gomatrixserverlib.DeviceKeys, which doesn't make sense.
- }
- }
- }
-
- return nil
-}
-
-func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
- ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
-) {
- for targetUserID := range req.UserToDevices {
- keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
- if err != nil {
- logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
- continue
- }
-
- for keyType, key := range keys {
- var keyID gomatrixserverlib.KeyID
- for id := range key.Keys {
- keyID = id
- break
- }
-
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
- if err != nil && err != sql.ErrNoRows {
- logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
- continue
- }
-
- appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) {
- if key.Signatures == nil {
- key.Signatures = types.CrossSigningSigMap{}
- }
- if _, ok := key.Signatures[originUserID]; !ok {
- key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes)
- }
- key.Signatures[originUserID][originKeyID] = signature
- }
-
- for originUserID, forOrigin := range sigMap {
- for originKeyID, signature := range forOrigin {
- switch {
- case req.UserID != "" && originUserID == req.UserID:
- // Include signatures that we created
- appendSignature(originUserID, originKeyID, signature)
- case originUserID == targetUserID:
- // Include signatures that were created by the person whose key
- // we are processing
- appendSignature(originUserID, originKeyID, signature)
- }
- }
- }
-
- switch keyType {
- case gomatrixserverlib.CrossSigningKeyPurposeMaster:
- res.MasterKeys[targetUserID] = key
-
- case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
- res.SelfSigningKeys[targetUserID] = key
-
- case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
- res.UserSigningKeys[targetUserID] = key
- }
- }
- }
-}
-
-func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
- for targetUserID, forTargetUser := range req.TargetIDs {
- keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
- if err != nil && err != sql.ErrNoRows {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
- }
- continue
- }
-
- for targetPurpose, targetKey := range keyMap {
- switch targetPurpose {
- case gomatrixserverlib.CrossSigningKeyPurposeMaster:
- if res.MasterKeys == nil {
- res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{}
- }
- res.MasterKeys[targetUserID] = targetKey
-
- case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
- if res.SelfSigningKeys == nil {
- res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
- }
- res.SelfSigningKeys[targetUserID] = targetKey
-
- case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
- if res.UserSigningKeys == nil {
- res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
- }
- res.UserSigningKeys[targetUserID] = targetKey
- }
- }
-
- for _, targetKeyID := range forTargetUser {
- // Get own signatures only.
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
- if err != nil && err != sql.ErrNoRows {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
- }
- return nil
- }
-
- for sourceUserID, forSourceUser := range sigMap {
- for sourceKeyID, sourceSig := range forSourceUser {
- if res.Signatures == nil {
- res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
- }
- if _, ok := res.Signatures[targetUserID]; !ok {
- res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
- }
- if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok {
- res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{}
- }
- if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok {
- res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig
- }
- }
- }
- }
- return nil
-}
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
deleted file mode 100644
index 1b00d1ee..00000000
--- a/keyserver/internal/device_list_update.go
+++ /dev/null
@@ -1,579 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package internal
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "hash/fnv"
- "net"
- "sync"
- "time"
-
- rsapi "github.com/matrix-org/dendrite/roomserver/api"
-
- "github.com/matrix-org/gomatrix"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/sirupsen/logrus"
-
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/setup/process"
-)
-
-var (
- deviceListUpdateCount = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Namespace: "dendrite",
- Subsystem: "keyserver",
- Name: "device_list_update",
- Help: "Number of times we have attempted to update device lists from this server",
- },
- []string{"server"},
- )
-)
-
-const requestTimeout = time.Second * 30
-
-func init() {
- prometheus.MustRegister(
- deviceListUpdateCount,
- )
-}
-
-// DeviceListUpdater handles device list updates from remote servers.
-//
-// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
-// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
-// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
-// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
-// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
-// updater stores the latest list along with the latest stream ID.
-//
-// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
-// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
-// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
-// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
-// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
-// for that domain.
-// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
-// we have many many servers)
-// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
-//
-// The downsides are that:
-// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
-// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
-// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
-// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
-// than being stuck behind foo.bar
-//
-// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
-// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
-type DeviceListUpdater struct {
- process *process.ProcessContext
- // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
- // request to the remote server and race.
- // TODO: Put in an LRU cache to bound growth
- userIDToMutex map[string]*sync.Mutex
- mu *sync.Mutex // protects UserIDToMutex
-
- db DeviceListUpdaterDatabase
- api DeviceListUpdaterAPI
- producer KeyChangeProducer
- fedClient fedsenderapi.KeyserverFederationAPI
- workerChans []chan gomatrixserverlib.ServerName
- thisServer gomatrixserverlib.ServerName
-
- // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
- // block on or timeout via a select.
- userIDToChan map[string]chan bool
- userIDToChanMu *sync.Mutex
- rsAPI rsapi.KeyserverRoomserverAPI
-}
-
-// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
-// Useful for testing.
-type DeviceListUpdaterDatabase interface {
- // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
- // If no domains are given, all user IDs with stale device lists are returned.
- StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
-
- // MarkDeviceListStale sets the stale bit for this user to isStale.
- MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
-
- // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
- // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
- StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
-
- // PrevIDsExists returns true if all prev IDs exist for this user.
- PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
-
- // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
- DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
-
- DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
-}
-
-type DeviceListUpdaterAPI interface {
- PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
-}
-
-// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
-type KeyChangeProducer interface {
- ProduceKeyChanges(keys []api.DeviceMessage) error
-}
-
-// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
-func NewDeviceListUpdater(
- process *process.ProcessContext, db DeviceListUpdaterDatabase,
- api DeviceListUpdaterAPI, producer KeyChangeProducer,
- fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
- rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
-) *DeviceListUpdater {
- return &DeviceListUpdater{
- process: process,
- userIDToMutex: make(map[string]*sync.Mutex),
- mu: &sync.Mutex{},
- db: db,
- api: api,
- producer: producer,
- fedClient: fedClient,
- thisServer: thisServer,
- workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
- userIDToChan: make(map[string]chan bool),
- userIDToChanMu: &sync.Mutex{},
- rsAPI: rsAPI,
- }
-}
-
-// Start the device list updater, which will try to refresh any stale device lists.
-func (u *DeviceListUpdater) Start() error {
- for i := 0; i < len(u.workerChans); i++ {
- // Allocate a small buffer per channel.
- // If the buffer limit is reached, backpressure will cause the processing of EDUs
- // to stop (in this transaction) until key requests can be made.
- ch := make(chan gomatrixserverlib.ServerName, 10)
- u.workerChans[i] = ch
- go u.worker(ch)
- }
-
- staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
- if err != nil {
- return err
- }
- offset, step := time.Second*10, time.Second
- if max := len(staleLists); max > 120 {
- step = (time.Second * 120) / time.Duration(max)
- }
- for _, userID := range staleLists {
- userID := userID // otherwise we are only sending the last entry
- time.AfterFunc(offset, func() {
- u.notifyWorkers(userID)
- })
- offset += step
- }
- return nil
-}
-
-// CleanUp removes stale device entries for users we don't share a room with anymore
-func (u *DeviceListUpdater) CleanUp() error {
- staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
- if err != nil {
- return err
- }
-
- res := rsapi.QueryLeftUsersResponse{}
- if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
- return err
- }
-
- if len(res.LeftUsers) == 0 {
- return nil
- }
- logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
- return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
-}
-
-func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
- u.mu.Lock()
- defer u.mu.Unlock()
- if u.userIDToMutex[userID] == nil {
- u.userIDToMutex[userID] = &sync.Mutex{}
- }
- return u.userIDToMutex[userID]
-}
-
-// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
-// Blocks until the device list is synced or the timeout is reached.
-func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
- mu := u.mutex(userID)
- mu.Lock()
- err := u.db.MarkDeviceListStale(ctx, userID, true)
- mu.Unlock()
- if err != nil {
- return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
- }
- u.notifyWorkers(userID)
- return nil
-}
-
-// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
-// which assumes when /send 200 OKs that the device lists have been updated.
-func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
- isDeviceListStale, err := u.update(ctx, event)
- if err != nil {
- return err
- }
- if isDeviceListStale {
- // poke workers to handle stale device lists
- u.notifyWorkers(event.UserID)
- }
- return nil
-}
-
-func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
- mu := u.mutex(event.UserID)
- mu.Lock()
- defer mu.Unlock()
- // check if we have the prev IDs
- exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
- if err != nil {
- return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
- }
- // if this is the first time we're hearing about this user, sync the device list manually.
- if len(event.PrevID) == 0 {
- exists = false
- }
- util.GetLogger(ctx).WithFields(logrus.Fields{
- "prev_ids_exist": exists,
- "user_id": event.UserID,
- "device_id": event.DeviceID,
- "stream_id": event.StreamID,
- "prev_ids": event.PrevID,
- "display_name": event.DeviceDisplayName,
- "deleted": event.Deleted,
- }).Trace("DeviceListUpdater.Update")
-
- // if we haven't missed anything update the database and notify users
- if exists || event.Deleted {
- k := event.Keys
- if event.Deleted {
- k = nil
- }
- keys := []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: event.DeviceID,
- DisplayName: event.DeviceDisplayName,
- KeyJSON: k,
- UserID: event.UserID,
- },
- StreamID: event.StreamID,
- },
- }
-
- // DeviceKeysJSON will side-effect modify this, so it needs
- // to be a copy, not sharing any pointers with the above.
- deviceKeysCopy := *keys[0].DeviceKeys
- deviceKeysCopy.KeyJSON = nil
- existingKeys := []api.DeviceMessage{
- {
- Type: keys[0].Type,
- DeviceKeys: &deviceKeysCopy,
- StreamID: keys[0].StreamID,
- },
- }
-
- // fetch what keys we had already and only emit changes
- if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
- // non-fatal, log and continue
- util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf(
- "failed to query device keys json for calculating diffs",
- )
- }
-
- err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
- if err != nil {
- return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
- }
-
- if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
- return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
- }
- return false, nil
- }
-
- err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
- if err != nil {
- return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
- }
-
- return true, nil
-}
-
-func (u *DeviceListUpdater) notifyWorkers(userID string) {
- _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- return
- }
- hash := fnv.New32a()
- _, _ = hash.Write([]byte(remoteServer))
- index := int(int64(hash.Sum32()) % int64(len(u.workerChans)))
-
- ch := u.assignChannel(userID)
- u.workerChans[index] <- remoteServer
- select {
- case <-ch:
- case <-time.After(10 * time.Second):
- // we don't return an error in this case as it's not a failure condition.
- // we mainly block for the benefit of sytest anyway
- }
-}
-
-func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
- u.userIDToChanMu.Lock()
- defer u.userIDToChanMu.Unlock()
- if ch, ok := u.userIDToChan[userID]; ok {
- return ch
- }
- ch := make(chan bool)
- u.userIDToChan[userID] = ch
- return ch
-}
-
-func (u *DeviceListUpdater) clearChannel(userID string) {
- u.userIDToChanMu.Lock()
- defer u.userIDToChanMu.Unlock()
- if ch, ok := u.userIDToChan[userID]; ok {
- close(ch)
- delete(u.userIDToChan, userID)
- }
-}
-
-func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
- retries := make(map[gomatrixserverlib.ServerName]time.Time)
- retriesMu := &sync.Mutex{}
- // restarter goroutine which will inject failed servers into ch when it is time
- go func() {
- var serversToRetry []gomatrixserverlib.ServerName
- for {
- serversToRetry = serversToRetry[:0] // reuse memory
- time.Sleep(time.Second)
- retriesMu.Lock()
- now := time.Now()
- for srv, retryAt := range retries {
- if now.After(retryAt) {
- serversToRetry = append(serversToRetry, srv)
- }
- }
- for _, srv := range serversToRetry {
- delete(retries, srv)
- }
- retriesMu.Unlock()
- for _, srv := range serversToRetry {
- ch <- srv
- }
- }
- }()
- for serverName := range ch {
- retriesMu.Lock()
- _, exists := retries[serverName]
- retriesMu.Unlock()
- if exists {
- // Don't retry a server that we're already waiting for.
- continue
- }
- waitTime, shouldRetry := u.processServer(serverName)
- if shouldRetry {
- retriesMu.Lock()
- if _, exists = retries[serverName]; !exists {
- retries[serverName] = time.Now().Add(waitTime)
- }
- retriesMu.Unlock()
- }
- }
-}
-
-func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
- ctx := u.process.Context()
- logger := util.GetLogger(ctx).WithField("server_name", serverName)
- deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
-
- waitTime := defaultWaitTime // How long should we wait to try again?
- successCount := 0 // How many user requests failed?
-
- userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
- if err != nil {
- logger.WithError(err).Error("Failed to load stale device lists")
- return waitTime, true
- }
-
- defer func() {
- for _, userID := range userIDs {
- // always clear the channel to unblock Update calls regardless of success/failure
- u.clearChannel(userID)
- }
- }()
-
- for _, userID := range userIDs {
- userWait, err := u.processServerUser(ctx, serverName, userID)
- if err != nil {
- if userWait > waitTime {
- waitTime = userWait
- }
- break
- }
- successCount++
- }
-
- allUsersSucceeded := successCount == len(userIDs)
- if !allUsersSucceeded {
- logger.WithFields(logrus.Fields{
- "total": len(userIDs),
- "succeeded": successCount,
- "failed": len(userIDs) - successCount,
- "wait_time": waitTime,
- }).Debug("Failed to query device keys for some users")
- }
- return waitTime, !allUsersSucceeded
-}
-
-func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
- ctx, cancel := context.WithTimeout(ctx, requestTimeout)
- defer cancel()
- logger := util.GetLogger(ctx).WithFields(logrus.Fields{
- "server_name": serverName,
- "user_id": userID,
- })
- res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
- if err != nil {
- if errors.Is(err, context.DeadlineExceeded) {
- return time.Minute * 10, err
- }
- switch e := err.(type) {
- case *json.UnmarshalTypeError, *json.SyntaxError:
- logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
- return defaultWaitTime, nil
- case *fedsenderapi.FederationClientError:
- if e.RetryAfter > 0 {
- return e.RetryAfter, err
- } else if e.Blacklisted {
- return time.Hour * 8, err
- }
- case net.Error:
- // Use the default waitTime, if it's a timeout.
- // It probably doesn't make sense to try further users.
- if !e.Timeout() {
- logger.WithError(e).Debug("GetUserDevices returned net.Error")
- return time.Minute * 10, err
- }
- case gomatrix.HTTPError:
- // The remote server returned an error, give it some time to recover.
- // This is to avoid spamming remote servers, which may not be Matrix servers anymore.
- if e.Code >= 300 {
- logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
- return hourWaitTime, err
- }
- default:
- // Something else failed
- logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
- return time.Minute * 10, err
- }
- }
- if res.UserID != userID {
- logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
- return defaultWaitTime, nil
- }
- if res.MasterKey != nil || res.SelfSigningKey != nil {
- uploadReq := &api.PerformUploadDeviceKeysRequest{
- UserID: userID,
- }
- uploadRes := &api.PerformUploadDeviceKeysResponse{}
- if res.MasterKey != nil {
- if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
- uploadReq.MasterKey = *res.MasterKey
- }
- }
- if res.SelfSigningKey != nil {
- if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
- uploadReq.SelfSigningKey = *res.SelfSigningKey
- }
- }
- _ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
- }
- err = u.updateDeviceList(&res)
- if err != nil {
- logger.WithError(err).Error("Fetched device list but failed to store/emit it")
- return defaultWaitTime, err
- }
- return defaultWaitTime, nil
-}
-
-func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
- ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
- keys := make([]api.DeviceMessage, len(res.Devices))
- existingKeys := make([]api.DeviceMessage, len(res.Devices))
- for i, device := range res.Devices {
- keyJSON, err := json.Marshal(device.Keys)
- if err != nil {
- util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
- continue
- }
- keys[i] = api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- StreamID: res.StreamID,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: device.DeviceID,
- DisplayName: device.DisplayName,
- UserID: res.UserID,
- KeyJSON: keyJSON,
- },
- }
- existingKeys[i] = api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: res.UserID,
- DeviceID: device.DeviceID,
- },
- }
- }
- // fetch what keys we had already and only emit changes
- if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
- // non-fatal, log and continue
- util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf(
- "failed to query device keys json for calculating diffs",
- )
- }
-
- err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
- if err != nil {
- return fmt.Errorf("failed to store remote device keys: %w", err)
- }
- err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
- if err != nil {
- return fmt.Errorf("failed to mark device list as fresh: %w", err)
- }
- err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
- if err != nil {
- return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
- }
- return nil
-}
diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go
deleted file mode 100644
index 7d357c95..00000000
--- a/keyserver/internal/device_list_update_default.go
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build !vw
-
-package internal
-
-import "time"
-
-const defaultWaitTime = time.Minute
-const hourWaitTime = time.Hour
diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go
deleted file mode 100644
index 1c60d2eb..00000000
--- a/keyserver/internal/device_list_update_sytest.go
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build vw
-
-package internal
-
-import "time"
-
-// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
-// results in a one-hour wait time from a previous device so the test times out. This is fine for
-// production, but makes an otherwise passing test fail.
-const defaultWaitTime = time.Second
-const hourWaitTime = time.Second
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
deleted file mode 100644
index 60a2c2f3..00000000
--- a/keyserver/internal/device_list_update_test.go
+++ /dev/null
@@ -1,431 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package internal
-
-import (
- "context"
- "crypto/ed25519"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "reflect"
- "strings"
- "sync"
- "testing"
- "time"
-
- "github.com/matrix-org/gomatrixserverlib"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
- roomserver "github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/setup/process"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-var (
- ctx = context.Background()
-)
-
-type mockKeyChangeProducer struct {
- events []api.DeviceMessage
-}
-
-func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
- p.events = append(p.events, keys...)
- return nil
-}
-
-type mockDeviceListUpdaterDatabase struct {
- staleUsers map[string]bool
- prevIDsExist func(string, []int64) bool
- storedKeys []api.DeviceMessage
- mu sync.Mutex // protect staleUsers
-}
-
-func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
- return nil
-}
-
-// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
-// If no domains are given, all user IDs with stale device lists are returned.
-func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- d.mu.Lock()
- defer d.mu.Unlock()
- var result []string
- for userID, isStale := range d.staleUsers {
- if !isStale {
- continue
- }
- _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- return nil, err
- }
- if len(domains) == 0 {
- result = append(result, userID)
- continue
- }
- for _, d := range domains {
- if remoteServer == d {
- result = append(result, userID)
- break
- }
- }
- }
- return result, nil
-}
-
-// MarkDeviceListStale sets the stale bit for this user to isStale.
-func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
- d.mu.Lock()
- defer d.mu.Unlock()
- d.staleUsers[userID] = isStale
- return nil
-}
-
-func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.staleUsers[userID]
-}
-
-// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
-// for this (user, device). Does not modify the stream ID for keys.
-func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
- d.storedKeys = append(d.storedKeys, keys...)
- return nil
-}
-
-// PrevIDsExists returns true if all prev IDs exist for this user.
-func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
- return d.prevIDsExist(userID, prevIDs), nil
-}
-
-func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- return nil
-}
-
-type mockDeviceListUpdaterAPI struct {
-}
-
-func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
- return nil
-}
-
-type roundTripper struct {
- fn func(*http.Request) (*http.Response, error)
-}
-
-func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
- return t.fn(req)
-}
-
-func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
- _, pkey, _ := ed25519.GenerateKey(nil)
- fedClient := gomatrixserverlib.NewFederationClient(
- []*gomatrixserverlib.SigningIdentity{
- {
- ServerName: gomatrixserverlib.ServerName("example.test"),
- KeyID: gomatrixserverlib.KeyID("ed25519:test"),
- PrivateKey: pkey,
- },
- },
- )
- fedClient.Client = *gomatrixserverlib.NewClient(
- gomatrixserverlib.WithTransport(&roundTripper{tripper}),
- )
- return fedClient
-}
-
-// Test that the device keys get persisted and emitted if we have the previous IDs.
-func TestUpdateHavePrevID(t *testing.T) {
- db := &mockDeviceListUpdaterDatabase{
- staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int64) bool {
- return true
- },
- }
- ap := &mockDeviceListUpdaterAPI{}
- producer := &mockKeyChangeProducer{}
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
- event := gomatrixserverlib.DeviceListUpdateEvent{
- DeviceDisplayName: "Foo Bar",
- Deleted: false,
- DeviceID: "FOO",
- Keys: []byte(`{"key":"value"}`),
- PrevID: []int64{0},
- StreamID: 1,
- UserID: "@alice:localhost",
- }
- err := updater.Update(ctx, event)
- if err != nil {
- t.Fatalf("Update returned an error: %s", err)
- }
- want := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- StreamID: event.StreamID,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: event.DeviceID,
- DisplayName: event.DeviceDisplayName,
- KeyJSON: event.Keys,
- UserID: event.UserID,
- },
- }
- if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
- t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
- }
- if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
- t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
- }
- if db.isStale(event.UserID) {
- t.Errorf("%s incorrectly marked as stale", event.UserID)
- }
-}
-
-// Test that device keys are fetched from the remote server if we are missing prev IDs
-// and that the user's devices are marked as stale until it succeeds.
-func TestUpdateNoPrevID(t *testing.T) {
- db := &mockDeviceListUpdaterDatabase{
- staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int64) bool {
- return false
- },
- }
- ap := &mockDeviceListUpdaterAPI{}
- producer := &mockKeyChangeProducer{}
- remoteUserID := "@alice:example.somewhere"
- var wg sync.WaitGroup
- wg.Add(1)
- keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
- fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
- defer wg.Done()
- if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
- return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
- }
- return &http.Response{
- StatusCode: 200,
- Body: io.NopCloser(strings.NewReader(`
- {
- "user_id": "` + remoteUserID + `",
- "stream_id": 5,
- "devices": [
- {
- "device_id": "JLAFKJWSCS",
- "keys": ` + keyJSON + `,
- "device_display_name": "Mobile Phone"
- }
- ]
- }
- `)),
- }, nil
- })
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
- if err := updater.Start(); err != nil {
- t.Fatalf("failed to start updater: %s", err)
- }
- event := gomatrixserverlib.DeviceListUpdateEvent{
- DeviceDisplayName: "Mobile Phone",
- Deleted: false,
- DeviceID: "another_device_id",
- Keys: []byte(`{"key":"value"}`),
- PrevID: []int64{3},
- StreamID: 4,
- UserID: remoteUserID,
- }
- err := updater.Update(ctx, event)
-
- if err != nil {
- t.Fatalf("Update returned an error: %s", err)
- }
- t.Log("waiting for /users/devices to be called...")
- wg.Wait()
- // wait a bit for db to be updated...
- time.Sleep(100 * time.Millisecond)
- want := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- StreamID: 5,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "JLAFKJWSCS",
- DisplayName: "Mobile Phone",
- UserID: remoteUserID,
- KeyJSON: []byte(keyJSON),
- },
- }
- // Now we should have a fresh list and the keys and emitted something
- if db.isStale(event.UserID) {
- t.Errorf("%s still marked as stale", event.UserID)
- }
- if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
- t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
- t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
- }
- if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
- t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
- }
-
-}
-
-// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
-// update is still ongoing.
-func TestDebounce(t *testing.T) {
- t.Skipf("panic on closed channel on GHA")
- db := &mockDeviceListUpdaterDatabase{
- staleUsers: make(map[string]bool),
- prevIDsExist: func(string, []int64) bool {
- return true
- },
- }
- ap := &mockDeviceListUpdaterAPI{}
- producer := &mockKeyChangeProducer{}
- fedCh := make(chan *http.Response, 1)
- srv := gomatrixserverlib.ServerName("example.com")
- userID := "@alice:example.com"
- keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
- incomingFedReq := make(chan struct{})
- fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
- if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
- return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
- }
- close(incomingFedReq)
- return <-fedCh, nil
- })
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
- if err := updater.Start(); err != nil {
- t.Fatalf("failed to start updater: %s", err)
- }
-
- // hit this 5 times
- var wg sync.WaitGroup
- wg.Add(5)
- for i := 0; i < 5; i++ {
- go func() {
- defer wg.Done()
- if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
- t.Errorf("ManualUpdate: %s", err)
- }
- }()
- }
-
- // wait until the updater hits federation
- select {
- case <-incomingFedReq:
- case <-time.After(time.Second):
- t.Fatalf("timed out waiting for updater to hit federation")
- }
-
- // user should be marked as stale
- if !db.isStale(userID) {
- t.Errorf("user %s not marked as stale", userID)
- }
- // now send the response over federation
- fedCh <- &http.Response{
- StatusCode: 200,
- Body: io.NopCloser(strings.NewReader(`
- {
- "user_id": "` + userID + `",
- "stream_id": 5,
- "devices": [
- {
- "device_id": "JLAFKJWSCS",
- "keys": ` + keyJSON + `,
- "device_display_name": "Mobile Phone"
- }
- ]
- }
- `)),
- }
- close(fedCh)
- // wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
- // and should panic with read on a closed channel
- wg.Wait()
-
- // user is no longer stale now
- if db.isStale(userID) {
- t.Errorf("user %s is marked as stale", userID)
- }
-}
-
-func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
- t.Helper()
-
- base, _, _ := testrig.Base(nil)
- connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
- if err != nil {
- t.Fatal(err)
- }
-
- return db, clearDB
-}
-
-type mockKeyserverRoomserverAPI struct {
- leftUsers []string
-}
-
-func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
- res.LeftUsers = m.leftUsers
- return nil
-}
-
-func TestDeviceListUpdater_CleanUp(t *testing.T) {
- processCtx := process.NewProcessContext()
-
- alice := test.NewUser(t)
- bob := test.NewUser(t)
-
- // Bob is not joined to any of our rooms
- rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
-
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clearDB := mustCreateKeyserverDB(t, dbType)
- defer clearDB()
-
- // This should not get deleted
- if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
- t.Error(err)
- }
-
- // this one should get deleted
- if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
- t.Error(err)
- }
-
- updater := NewDeviceListUpdater(processCtx, db, nil,
- nil, nil,
- 0, rsAPI, "test")
- if err := updater.CleanUp(); err != nil {
- t.Error(err)
- }
-
- // check that we still have Alice in our stale list
- staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
- if err != nil {
- t.Error(err)
- }
-
- // There should only be Alice
- wantCount := 1
- if count := len(staleUsers); count != wantCount {
- t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
- }
-
- if staleUsers[0] != alice.ID {
- t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
- }
- })
-}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
deleted file mode 100644
index 9a08a0bb..00000000
--- a/keyserver/internal/internal.go
+++ /dev/null
@@ -1,816 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package internal
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "sync"
- "time"
-
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
- "github.com/sirupsen/logrus"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
-
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/producers"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/config"
- userapi "github.com/matrix-org/dendrite/userapi/api"
-)
-
-type KeyInternalAPI struct {
- DB storage.Database
- Cfg *config.KeyServer
- FedClient fedsenderapi.KeyserverFederationAPI
- UserAPI userapi.KeyserverUserAPI
- Producer *producers.KeyChange
- Updater *DeviceListUpdater
-}
-
-func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
- a.UserAPI = i
-}
-
-func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
- userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
- if err != nil {
- res.Error = &api.KeyError{
- Err: err.Error(),
- }
- return nil
- }
- res.Offset = latest
- res.UserIDs = userIDs
- return nil
-}
-
-func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
- res.KeyErrors = make(map[string]map[string]*api.KeyError)
- if len(req.DeviceKeys) > 0 {
- a.uploadLocalDeviceKeys(ctx, req, res)
- }
- if len(req.OneTimeKeys) > 0 {
- a.uploadOneTimeKeys(ctx, req, res)
- }
- otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
- if err != nil {
- return err
- }
- res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
- return nil
-}
-
-func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
- res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
- res.Failures = make(map[string]interface{})
- // wrap request map in a top-level by-domain map
- domainToDeviceKeys := make(map[string]map[string]map[string]string)
- for userID, val := range req.OneTimeKeys {
- _, serverName, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- continue // ignore invalid users
- }
- nested, ok := domainToDeviceKeys[string(serverName)]
- if !ok {
- nested = make(map[string]map[string]string)
- }
- nested[userID] = val
- domainToDeviceKeys[string(serverName)] = nested
- }
- for domain, local := range domainToDeviceKeys {
- if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
- continue
- }
- // claim local keys
- keys, err := a.DB.ClaimKeys(ctx, local)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
- }
- }
- util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
- for _, key := range keys {
- _, ok := res.OneTimeKeys[key.UserID]
- if !ok {
- res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
- }
- _, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
- if !ok {
- res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
- }
- for keyID, keyJSON := range key.KeyJSON {
- res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
- }
- }
- delete(domainToDeviceKeys, domain)
- }
- if len(domainToDeviceKeys) > 0 {
- a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
- }
- return nil
-}
-
-func (a *KeyInternalAPI) claimRemoteKeys(
- ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
-) {
- var wg sync.WaitGroup // Wait for fan-out goroutines to finish
- var mu sync.Mutex // Protects the response struct
- var claimed int // Number of keys claimed in total
- var failures int // Number of servers we failed to ask
-
- util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
- wg.Add(len(domainToDeviceKeys))
-
- for d, k := range domainToDeviceKeys {
- go func(domain string, keysToClaim map[string]map[string]string) {
- fedCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
- defer wg.Done()
-
- claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
-
- mu.Lock()
- defer mu.Unlock()
-
- if err != nil {
- util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
- res.Failures[domain] = map[string]interface{}{
- "message": err.Error(),
- }
- failures++
- return
- }
-
- for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
- res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
- for deviceID, keys := range deviceIDToKeys {
- res.OneTimeKeys[userID][deviceID] = keys
- claimed += len(keys)
- }
- }
- }(d, k)
- }
-
- wg.Wait()
- util.GetLogger(ctx).WithFields(logrus.Fields{
- "num_keys": claimed,
- "num_failures": failures,
- }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
-}
-
-func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
- if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("Failed to delete device keys: %s", err),
- }
- }
- return nil
-}
-
-func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
- count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
- }
- return nil
- }
- res.Count = *count
- return nil
-}
-
-func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
- msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
- }
- return nil
- }
- maxStreamID := int64(0)
- // remove deleted devices
- var result []api.DeviceMessage
- for _, m := range msgs {
- if m.StreamID > maxStreamID {
- maxStreamID = m.StreamID
- }
- if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
- continue
- }
- result = append(result, m)
- }
- res.Devices = result
- res.StreamID = maxStreamID
- return nil
-}
-
-// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
-// in our database.
-func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
- knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
- if err != nil {
- return err
- }
- if len(knownDevices) == 0 {
- return nil // fmt.Errorf("unknown user %s", req.UserID)
- }
-
- for i := range knownDevices {
- if knownDevices[i].DeviceID == req.DeviceID {
- return nil // we already know about this device
- }
- }
-
- return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
-}
-
-// nolint:gocyclo
-func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
- var respMu sync.Mutex
- res.DeviceKeys = make(map[string]map[string]json.RawMessage)
- res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
- res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
- res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
- res.Failures = make(map[string]interface{})
-
- // make a map from domain to device keys
- domainToDeviceKeys := make(map[string]map[string][]string)
- domainToCrossSigningKeys := make(map[string]map[string]struct{})
- for userID, deviceIDs := range req.UserToDevices {
- _, serverName, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- continue // ignore invalid users
- }
- domain := string(serverName)
- // query local devices
- if a.Cfg.Matrix.IsLocalServerName(serverName) {
- deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to query local device keys: %s", err),
- }
- return nil
- }
-
- // pull out display names after we have the keys so we handle wildcards correctly
- var dids []string
- for _, dk := range deviceKeys {
- dids = append(dids, dk.DeviceID)
- }
- var queryRes userapi.QueryDeviceInfosResponse
- err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
- DeviceIDs: dids,
- }, &queryRes)
- if err != nil {
- util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
- }
-
- if res.DeviceKeys[userID] == nil {
- res.DeviceKeys[userID] = make(map[string]json.RawMessage)
- }
- for _, dk := range deviceKeys {
- if len(dk.KeyJSON) == 0 {
- continue // don't include blank keys
- }
- // inject display name if known (either locally or remotely)
- displayName := dk.DisplayName
- if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
- displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
- }
- dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
- DisplayName string `json:"device_display_name,omitempty"`
- }{displayName})
- res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
- }
- } else {
- domainToDeviceKeys[domain] = make(map[string][]string)
- domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
- }
- // work out if our cross-signing request for this user was
- // satisfied, if not add them to the list of things to fetch
- if _, ok := res.MasterKeys[userID]; !ok {
- if _, ok := domainToCrossSigningKeys[domain]; !ok {
- domainToCrossSigningKeys[domain] = make(map[string]struct{})
- }
- domainToCrossSigningKeys[domain][userID] = struct{}{}
- }
- if _, ok := res.SelfSigningKeys[userID]; !ok {
- if _, ok := domainToCrossSigningKeys[domain]; !ok {
- domainToCrossSigningKeys[domain] = make(map[string]struct{})
- }
- domainToCrossSigningKeys[domain][userID] = struct{}{}
- }
- }
-
- // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
- domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
- if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
- // perform key queries for remote devices
- a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
- }
-
- // Now that we've done the potentially expensive work of asking the federation,
- // try filling the cross-signing keys from the database that we know about.
- a.crossSigningKeysFromDatabase(ctx, req, res)
-
- // Finally, append signatures that we know about
- // TODO: This is horrible because we need to round-trip the signature from
- // JSON, add the signatures and marshal it again, for some reason?
-
- for targetUserID, masterKey := range res.MasterKeys {
- if masterKey.Signatures == nil {
- masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- for targetKeyID := range masterKey.Keys {
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
- if err != nil {
- // Stop executing the function if the context was canceled/the deadline was exceeded,
- // as we can't continue without a valid context.
- if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
- return nil
- }
- logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
- continue
- }
- if len(sigMap) == 0 {
- continue
- }
- for sourceUserID, forSourceUser := range sigMap {
- for sourceKeyID, sourceSig := range forSourceUser {
- if _, ok := masterKey.Signatures[sourceUserID]; !ok {
- masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
- }
- }
- }
- }
-
- for targetUserID, forUserID := range res.DeviceKeys {
- for targetKeyID, key := range forUserID {
- sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
- if err != nil {
- // Stop executing the function if the context was canceled/the deadline was exceeded,
- // as we can't continue without a valid context.
- if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
- return nil
- }
- logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
- continue
- }
- if len(sigMap) == 0 {
- continue
- }
- var deviceKey gomatrixserverlib.DeviceKeys
- if err = json.Unmarshal(key, &deviceKey); err != nil {
- continue
- }
- if deviceKey.Signatures == nil {
- deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- for sourceUserID, forSourceUser := range sigMap {
- for sourceKeyID, sourceSig := range forSourceUser {
- if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
- deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
- }
- }
- if js, err := json.Marshal(deviceKey); err == nil {
- res.DeviceKeys[targetUserID][targetKeyID] = js
- }
- }
- }
- return nil
-}
-
-func (a *KeyInternalAPI) remoteKeysFromDatabase(
- ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
-) map[string]map[string][]string {
- fetchRemote := make(map[string]map[string][]string)
- for domain, userToDeviceMap := range domainToDeviceKeys {
- for userID, deviceIDs := range userToDeviceMap {
- // we can't safely return keys from the db when all devices are requested as we don't
- // know if one has just been added.
- if len(deviceIDs) > 0 {
- err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
- if err == nil {
- continue
- }
- util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
- }
- // fetch device lists from remote
- if _, ok := fetchRemote[domain]; !ok {
- fetchRemote[domain] = make(map[string][]string)
- }
- fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
-
- }
- }
- return fetchRemote
-}
-
-func (a *KeyInternalAPI) queryRemoteKeys(
- ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
- domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
-) {
- resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
- // allows us to wait until all federation servers have been poked
- var wg sync.WaitGroup
- // mutex for writing directly to res (e.g failures)
- var respMu sync.Mutex
-
- domains := map[string]struct{}{}
- for domain := range domainToDeviceKeys {
- if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
- continue
- }
- domains[domain] = struct{}{}
- }
- for domain := range domainToCrossSigningKeys {
- if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
- continue
- }
- domains[domain] = struct{}{}
- }
- wg.Add(len(domains))
-
- // fan out
- for domain := range domains {
- go a.queryRemoteKeysOnServer(
- ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
- &wg, &respMu, timeout, resultCh, res,
- )
- }
-
- // Close the result channel when the goroutines have quit so the for .. range exits
- go func() {
- wg.Wait()
- close(resultCh)
- }()
-
- processResult := func(result *gomatrixserverlib.RespQueryKeys) {
- respMu.Lock()
- defer respMu.Unlock()
- for userID, nest := range result.DeviceKeys {
- res.DeviceKeys[userID] = make(map[string]json.RawMessage)
- for deviceID, deviceKey := range nest {
- keyJSON, err := json.Marshal(deviceKey)
- if err != nil {
- continue
- }
- res.DeviceKeys[userID][deviceID] = keyJSON
- }
- }
-
- for userID, body := range result.MasterKeys {
- res.MasterKeys[userID] = body
- }
-
- for userID, body := range result.SelfSigningKeys {
- res.SelfSigningKeys[userID] = body
- }
-
- // TODO: do we want to persist these somewhere now
- // that we have fetched them?
- }
-
- for result := range resultCh {
- processResult(result)
- }
-}
-
-func (a *KeyInternalAPI) queryRemoteKeysOnServer(
- ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
- wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
- res *api.QueryKeysResponse,
-) {
- defer wg.Done()
- fedCtx := ctx
- if timeout > 0 {
- var cancel context.CancelFunc
- fedCtx, cancel = context.WithTimeout(ctx, timeout)
- defer cancel()
- }
- // for users who we do not have any knowledge about, try to start doing device list updates for them
- // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
- // lack a stream ID.
- userIDsForAllDevices := map[string]struct{}{}
- for userID, deviceIDs := range devKeys {
- if len(deviceIDs) == 0 {
- userIDsForAllDevices[userID] = struct{}{}
- }
- }
- // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
- // a device list update, so we'll populate those back into the /keys/query list if not
- for userID := range crossSigningKeys {
- if devKeys == nil {
- devKeys = map[string][]string{}
- }
- if _, ok := userIDsForAllDevices[userID]; !ok {
- devKeys[userID] = []string{}
- }
- }
- for userID := range userIDsForAllDevices {
- err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
- if err != nil {
- logrus.WithFields(logrus.Fields{
- logrus.ErrorKey: err,
- "user_id": userID,
- "server": serverName,
- }).Error("Failed to manually update device lists for user")
- // try to do it via /keys/query
- devKeys[userID] = []string{}
- continue
- }
- // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
- // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
- err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
- if err != nil {
- logrus.WithFields(logrus.Fields{
- logrus.ErrorKey: err,
- "user_id": userID,
- "server": serverName,
- }).Error("Failed to manually update device lists for user")
- // try to do it via /keys/query
- devKeys[userID] = []string{}
- continue
- }
- }
- if len(devKeys) == 0 {
- return
- }
- queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
- if err == nil {
- resultCh <- &queryKeysResp
- return
- }
- respMu.Lock()
- res.Failures[serverName] = map[string]interface{}{
- "message": err.Error(),
- }
- respMu.Unlock()
-
- // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
- // is down, better to return something than nothing at all. Clients can know about the failure by
- // inspecting the failures map though so they can know it's a cached response.
- for userID, dkeys := range devKeys {
- // drop the error as it's already a failure at this point
- _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
- }
-
- // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
- respMu.Lock()
- if len(res.DeviceKeys) > 0 {
- delete(res.Failures, serverName)
- }
- respMu.Unlock()
-}
-
-func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
- ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
-) error {
- keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
- // if we can't query the db or there are fewer keys than requested, fetch from remote.
- if err != nil {
- return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
- }
- if len(keys) < len(deviceIDs) {
- return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
- }
- if len(deviceIDs) == 0 && len(keys) == 0 {
- return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
- }
- respMu.Lock()
- if res.DeviceKeys[userID] == nil {
- res.DeviceKeys[userID] = make(map[string]json.RawMessage)
- }
- respMu.Unlock()
-
- for _, key := range keys {
- if len(key.KeyJSON) == 0 {
- continue // ignore deleted keys
- }
- // inject the display name
- key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
- DisplayName string `json:"device_display_name,omitempty"`
- }{key.DisplayName})
- respMu.Lock()
- res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
- respMu.Unlock()
- }
- return nil
-}
-
-func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
- // get a list of devices from the user API that actually exist, as
- // we won't store keys for devices that don't exist
- uapidevices := &userapi.QueryDevicesResponse{}
- if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
- res.Error = &api.KeyError{
- Err: err.Error(),
- }
- return
- }
- if !uapidevices.UserExists {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("user %q does not exist", req.UserID),
- }
- return
- }
- existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
- for _, key := range uapidevices.Devices {
- existingDeviceMap[key.ID] = struct{}{}
- }
-
- // Get all of the user existing device keys so we can check for changes.
- existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
- }
- return
- }
-
- // Work out whether we have device keys in the keyserver for devices that
- // no longer exist in the user API. This is mostly an exercise to ensure
- // that we keep some integrity between the two.
- var toClean []gomatrixserverlib.KeyID
- for _, k := range existingKeys {
- if _, ok := existingDeviceMap[k.DeviceID]; !ok {
- toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
- }
- }
-
- if len(toClean) > 0 {
- if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
- logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
- } else {
- logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
- }
- }
-
- var keysToStore []api.DeviceMessage
-
- if req.OnlyDisplayNameUpdates {
- for _, existingKey := range existingKeys {
- for _, newKey := range req.DeviceKeys {
- switch {
- case existingKey.UserID != newKey.UserID:
- continue
- case existingKey.DeviceID != newKey.DeviceID:
- continue
- case existingKey.DisplayName != newKey.DisplayName:
- existingKey.DisplayName = newKey.DisplayName
- }
- }
- keysToStore = append(keysToStore, existingKey)
- }
- } else {
- // assert that the user ID / device ID are not lying for each key
- for _, key := range req.DeviceKeys {
- var serverName gomatrixserverlib.ServerName
- _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
- if err != nil {
- continue // ignore invalid users
- }
- if !a.Cfg.Matrix.IsLocalServerName(serverName) {
- continue // ignore remote users
- }
- if len(key.KeyJSON) == 0 {
- keysToStore = append(keysToStore, key.WithStreamID(0))
- continue // deleted keys don't need sanity checking
- }
- // check that the device in question actually exists in the user
- // API before we try and store a key for it
- if _, ok := existingDeviceMap[key.DeviceID]; !ok {
- continue
- }
- gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
- gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
- if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
- keysToStore = append(keysToStore, key.WithStreamID(0))
- continue
- }
-
- res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
- Err: fmt.Sprintf(
- "user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
- gotUserID, key.UserID, gotDeviceID, key.DeviceID,
- ),
- })
- }
- }
-
- // store the device keys and emit changes
- err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
- }
- return
- }
- err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
- if err != nil {
- util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
- }
-}
-
-func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
- if req.UserID == "" {
- res.Error = &api.KeyError{
- Err: "user ID missing",
- }
- }
- if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
- counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
- if err != nil {
- res.Error = &api.KeyError{
- Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
- }
- }
- if counts != nil {
- res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
- }
- return
- }
- for _, key := range req.OneTimeKeys {
- // grab existing keys based on (user/device/algorithm/key ID)
- keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
- i := 0
- for keyIDWithAlgo := range key.KeyJSON {
- keyIDsWithAlgorithms[i] = keyIDWithAlgo
- i++
- }
- existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
- if err != nil {
- res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
- Err: "failed to query existing one-time keys: " + err.Error(),
- })
- continue
- }
- for keyIDWithAlgo := range existingKeys {
- // if keys exist and the JSON doesn't match, error out as the key already exists
- if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
- res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
- Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
- })
- continue
- }
- }
- // store one-time keys
- counts, err := a.DB.StoreOneTimeKeys(ctx, key)
- if err != nil {
- res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
- Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
- })
- continue
- }
- // collect counts
- res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
- }
-
-}
-
-func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
- // if we only want to update the display names, we can skip the checks below
- if onlyUpdateDisplayName {
- return producer.ProduceKeyChanges(new)
- }
- // find keys in new that are not in existing
- var keysAdded []api.DeviceMessage
- for _, newKey := range new {
- exists := false
- for _, existingKey := range existing {
- // Do not treat the absence of keys as equal, or else we will not emit key changes
- // when users delete devices which never had a key to begin with as both KeyJSONs are nil.
- if existingKey.DeviceKeysEqual(&newKey) {
- exists = true
- break
- }
- }
- if !exists {
- keysAdded = append(keysAdded, newKey)
- }
- }
- return producer.ProduceKeyChanges(keysAdded)
-}
diff --git a/keyserver/internal/internal_test.go b/keyserver/internal/internal_test.go
deleted file mode 100644
index 8a2c9c5d..00000000
--- a/keyserver/internal/internal_test.go
+++ /dev/null
@@ -1,156 +0,0 @@
-package internal_test
-
-import (
- "context"
- "reflect"
- "testing"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/test"
-)
-
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
- t.Helper()
- connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewDatabase(nil, &config.DatabaseOptions{
- ConnectionString: config.DataSource(connStr),
- })
- if err != nil {
- t.Fatalf("failed to create new user db: %v", err)
- }
- return db, close
-}
-
-func Test_QueryDeviceMessages(t *testing.T) {
- alice := test.NewUser(t)
- type args struct {
- req *api.QueryDeviceMessagesRequest
- res *api.QueryDeviceMessagesResponse
- }
- tests := []struct {
- name string
- args args
- wantErr bool
- want *api.QueryDeviceMessagesResponse
- }{
- {
- name: "no existing keys",
- args: args{
- req: &api.QueryDeviceMessagesRequest{
- UserID: "@doesNotExist:localhost",
- },
- res: &api.QueryDeviceMessagesResponse{},
- },
- want: &api.QueryDeviceMessagesResponse{},
- },
- {
- name: "existing user returns devices",
- args: args{
- req: &api.QueryDeviceMessagesRequest{
- UserID: alice.ID,
- },
- res: &api.QueryDeviceMessagesResponse{},
- },
- want: &api.QueryDeviceMessagesResponse{
- StreamID: 6,
- Devices: []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- DisplayName: "first device",
- UserID: alice.ID,
- KeyJSON: []byte("ghi"),
- },
- },
- {
- Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
- DeviceID: "mySecondDevice",
- DisplayName: "second device",
- UserID: alice.ID,
- KeyJSON: []byte("jkl"),
- }, // streamID 6
- },
- },
- },
- },
- }
-
- deviceMessages := []api.DeviceMessage{
- { // not the user we're looking for
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- UserID: "@doesNotExist:localhost",
- },
- // streamID 1 for this user
- },
- { // empty keyJSON will be ignored
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- UserID: alice.ID,
- }, // streamID 1
- },
- {
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- UserID: alice.ID,
- KeyJSON: []byte("abc"),
- }, // streamID 2
- },
- {
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- UserID: alice.ID,
- KeyJSON: []byte("def"),
- }, // streamID 3
- },
- {
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- UserID: alice.ID,
- KeyJSON: []byte(""),
- }, // streamID 4
- },
- {
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "myDevice",
- DisplayName: "first device",
- UserID: alice.ID,
- KeyJSON: []byte("ghi"),
- }, // streamID 5
- },
- {
- Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
- DeviceID: "mySecondDevice",
- UserID: alice.ID,
- KeyJSON: []byte("jkl"),
- DisplayName: "second device",
- }, // streamID 6
- },
- }
- ctx := context.Background()
-
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, closeDB := mustCreateDatabase(t, dbType)
- defer closeDB()
- if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
- t.Fatalf("failed to store local devicesKeys")
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- a := &internal.KeyInternalAPI{
- DB: db,
- }
- if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
- t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
- }
- got := tt.args.res
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
- }
- })
- }
- })
-}
diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go
deleted file mode 100644
index 2d143682..00000000
--- a/keyserver/keyserver.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package keyserver
-
-import (
- "github.com/sirupsen/logrus"
-
- rsapi "github.com/matrix-org/dendrite/roomserver/api"
-
- fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/consumers"
- "github.com/matrix-org/dendrite/keyserver/internal"
- "github.com/matrix-org/dendrite/keyserver/producers"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/setup/jetstream"
-)
-
-// NewInternalAPI returns a concerete implementation of the internal API. Callers
-// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
-func NewInternalAPI(
- base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
- rsAPI rsapi.KeyserverRoomserverAPI,
-) api.KeyInternalAPI {
- js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
-
- db, err := storage.NewDatabase(base, &cfg.Database)
- if err != nil {
- logrus.WithError(err).Panicf("failed to connect to key server database")
- }
-
- keyChangeProducer := &producers.KeyChange{
- Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
- JetStream: js,
- DB: db,
- }
- ap := &internal.KeyInternalAPI{
- DB: db,
- Cfg: cfg,
- FedClient: fedClient,
- Producer: keyChangeProducer,
- }
- updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
- ap.Updater = updater
-
- // Remove users which we don't share a room with anymore
- if err := updater.CleanUp(); err != nil {
- logrus.WithError(err).Error("failed to cleanup stale device lists")
- }
-
- go func() {
- if err := updater.Start(); err != nil {
- logrus.WithError(err).Panicf("failed to start device list updater")
- }
- }()
-
- dlConsumer := consumers.NewDeviceListUpdateConsumer(
- base.ProcessContext, cfg, js, updater,
- )
- if err := dlConsumer.Start(); err != nil {
- logrus.WithError(err).Panic("failed to start device list consumer")
- }
-
- sigConsumer := consumers.NewSigningKeyUpdateConsumer(
- base.ProcessContext, cfg, js, ap,
- )
- if err := sigConsumer.Start(); err != nil {
- logrus.WithError(err).Panic("failed to start signing key consumer")
- }
-
- return ap
-}
diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go
deleted file mode 100644
index 159b280f..00000000
--- a/keyserver/keyserver_test.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package keyserver
-
-import (
- "context"
- "testing"
-
- roomserver "github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-type mockKeyserverRoomserverAPI struct {
- leftUsers []string
-}
-
-func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
- res.LeftUsers = m.leftUsers
- return nil
-}
-
-// Merely tests that we can create an internal keyserver API
-func Test_NewInternalAPI(t *testing.T) {
- rsAPI := &mockKeyserverRoomserverAPI{}
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- base, closeBase := testrig.CreateBaseDendrite(t, dbType)
- defer closeBase()
- _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
- })
-}
diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go
deleted file mode 100644
index f86c3417..00000000
--- a/keyserver/producers/keychange.go
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package producers
-
-import (
- "context"
- "encoding/json"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/setup/jetstream"
- "github.com/nats-io/nats.go"
- "github.com/sirupsen/logrus"
-)
-
-// KeyChange produces key change events for the sync API and federation sender to consume
-type KeyChange struct {
- Topic string
- JetStream nats.JetStreamContext
- DB storage.Database
-}
-
-// ProduceKeyChanges creates new change events for each key
-func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
- userToDeviceCount := make(map[string]int)
- for _, key := range keys {
- id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
- if err != nil {
- return err
- }
- key.DeviceChangeID = id
- value, err := json.Marshal(key)
- if err != nil {
- return err
- }
-
- m := &nats.Msg{
- Subject: p.Topic,
- Header: nats.Header{},
- }
- m.Header.Set(jetstream.UserID, key.UserID)
- m.Data = value
-
- _, err = p.JetStream.PublishMsg(m)
- if err != nil {
- return err
- }
-
- userToDeviceCount[key.UserID]++
- }
- for userID, count := range userToDeviceCount {
- logrus.WithFields(logrus.Fields{
- "user_id": userID,
- "num_key_changes": count,
- }).Tracef("Produced to key change topic '%s'", p.Topic)
- }
- return nil
-}
-
-func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error {
- output := &api.DeviceMessage{
- Type: api.TypeCrossSigningUpdate,
- OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{
- CrossSigningKeyUpdate: key,
- },
- }
-
- id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
- if err != nil {
- return err
- }
- output.DeviceChangeID = id
-
- value, err := json.Marshal(output)
- if err != nil {
- return err
- }
-
- m := &nats.Msg{
- Subject: p.Topic,
- Header: nats.Header{},
- }
- m.Header.Set(jetstream.UserID, key.UserID)
- m.Data = value
-
- _, err = p.JetStream.PublishMsg(m)
- if err != nil {
- return err
- }
-
- logrus.WithFields(logrus.Fields{
- "user_id": key.UserID,
- }).Tracef("Produced to cross-signing update topic '%s'", p.Topic)
- return nil
-}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
deleted file mode 100644
index c6a8f44c..00000000
--- a/keyserver/storage/interface.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package storage
-
-import (
- "context"
- "encoding/json"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type Database interface {
- // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
- // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
- ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
-
- // StoreOneTimeKeys persists the given one-time keys.
- StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
-
- // OneTimeKeysCount returns a count of all OTKs for this device.
- OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
-
- // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
- DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
-
- // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device).
- // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
- // Returns an error if there was a problem storing the keys.
- StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
-
- // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
- // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
- StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
-
- // PrevIDsExists returns true if all prev IDs exist for this user.
- PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
-
- // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
- // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
- DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
-
- // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
- // cross-signing signatures relating to that device.
- DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
-
- // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
- // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
- ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
-
- // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
- // `userID` is the the user who has changed their keys in some way.
- StoreKeyChange(ctx context.Context, userID string) (int64, error)
-
- // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
- // A to offset of types.OffsetNewest means no upper limit.
- // Returns the offset of the latest key change.
- KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
-
- // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
- // If no domains are given, all user IDs with stale device lists are returned.
- StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
-
- // MarkDeviceListStale sets the stale bit for this user to isStale.
- MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
-
- CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
- CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
- CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
-
- StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
- StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
-
- DeleteStaleDeviceLists(
- ctx context.Context,
- userIDs []string,
- ) error
-}
diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go
deleted file mode 100644
index 1022157e..00000000
--- a/keyserver/storage/postgres/cross_signing_keys_table.go
+++ /dev/null
@@ -1,102 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var crossSigningKeysSchema = `
-CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
- user_id TEXT NOT NULL,
- key_type SMALLINT NOT NULL,
- key_data TEXT NOT NULL,
- PRIMARY KEY (user_id, key_type)
-);
-`
-
-const selectCrossSigningKeysForUserSQL = "" +
- "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
- " WHERE user_id = $1"
-
-const upsertCrossSigningKeysForUserSQL = "" +
- "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
- " VALUES($1, $2, $3)" +
- " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
-
-type crossSigningKeysStatements struct {
- db *sql.DB
- selectCrossSigningKeysForUserStmt *sql.Stmt
- upsertCrossSigningKeysForUserStmt *sql.Stmt
-}
-
-func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
- s := &crossSigningKeysStatements{
- db: db,
- }
- _, err := db.Exec(crossSigningKeysSchema)
- if err != nil {
- return nil, err
- }
- return s, sqlutil.StatementList{
- {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
- {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
- }.Prepare(db)
-}
-
-func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
- ctx context.Context, txn *sql.Tx, userID string,
-) (r types.CrossSigningKeyMap, err error) {
- rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
- r = types.CrossSigningKeyMap{}
- for rows.Next() {
- var keyTypeInt int16
- var keyData gomatrixserverlib.Base64Bytes
- if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
- return nil, err
- }
- keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
- if !ok {
- return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
- }
- r[keyType] = keyData
- }
- return
-}
-
-func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
- ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
-) error {
- keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
- if !ok {
- return fmt.Errorf("unknown key purpose %q", keyType)
- }
- if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
- return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go
deleted file mode 100644
index 4536b7d8..00000000
--- a/keyserver/storage/postgres/cross_signing_sigs_table.go
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var crossSigningSigsSchema = `
-CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
- origin_user_id TEXT NOT NULL,
- origin_key_id TEXT NOT NULL,
- target_user_id TEXT NOT NULL,
- target_key_id TEXT NOT NULL,
- signature TEXT NOT NULL,
- PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
-`
-
-const selectCrossSigningSigsForTargetSQL = "" +
- "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
- " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
-
-const upsertCrossSigningSigsForTargetSQL = "" +
- "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
- " VALUES($1, $2, $3, $4, $5)" +
- " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
-
-const deleteCrossSigningSigsForTargetSQL = "" +
- "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
-
-type crossSigningSigsStatements struct {
- db *sql.DB
- selectCrossSigningSigsForTargetStmt *sql.Stmt
- upsertCrossSigningSigsForTargetStmt *sql.Stmt
- deleteCrossSigningSigsForTargetStmt *sql.Stmt
-}
-
-func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
- s := &crossSigningSigsStatements{
- db: db,
- }
- _, err := db.Exec(crossSigningSigsSchema)
- if err != nil {
- return nil, err
- }
-
- m := sqlutil.NewMigrator(db)
- m.AddMigrations(sqlutil.Migration{
- Version: "keyserver: cross signing signature indexes",
- Up: deltas.UpFixCrossSigningSignatureIndexes,
- })
- if err = m.Up(context.Background()); err != nil {
- return nil, err
- }
-
- return s, sqlutil.StatementList{
- {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
- {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
- {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
- }.Prepare(db)
-}
-
-func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
-) (r types.CrossSigningSigMap, err error) {
- rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
- r = types.CrossSigningSigMap{}
- for rows.Next() {
- var userID string
- var keyID gomatrixserverlib.KeyID
- var signature gomatrixserverlib.Base64Bytes
- if err := rows.Scan(&userID, &keyID, &signature); err != nil {
- return nil, err
- }
- if _, ok := r[userID]; !ok {
- r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- r[userID][keyID] = signature
- }
- return
-}
-
-func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx,
- originUserID string, originKeyID gomatrixserverlib.KeyID,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
- signature gomatrixserverlib.Base64Bytes,
-) error {
- if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
- return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
- }
- return nil
-}
-
-func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
-) error {
- if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
- return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
deleted file mode 100644
index 0cfe9e79..00000000
--- a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package deltas
-
-import (
- "context"
- "database/sql"
- "fmt"
-)
-
-func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
- // start counting from the last max offset, else 0. We need to do a count(*) first to see if there
- // even are entries in this table to know if we can query for log_offset. Without the count then
- // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
- // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
- var count int
- _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
- if count > 0 {
- var maxOffset int64
- _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
- if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
- return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
- }
- }
-
- _, err := tx.ExecContext(ctx, `
- -- make the new table
- DROP TABLE IF EXISTS keyserver_key_changes;
- CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
- user_id TEXT NOT NULL,
- CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
- );
- `)
- if err != nil {
- return fmt.Errorf("failed to execute upgrade: %w", err)
- }
- return nil
-}
-
-func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
- DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
- DROP TABLE IF EXISTS keyserver_key_changes;
- CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- partition BIGINT NOT NULL,
- log_offset BIGINT NOT NULL,
- user_id TEXT NOT NULL,
- CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
- );
- `)
- if err != nil {
- return fmt.Errorf("failed to execute downgrade: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
deleted file mode 100644
index 1a3d4fee..00000000
--- a/keyserver/storage/postgres/deltas/2022042612000000_xsigning_idx.go
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package deltas
-
-import (
- "context"
- "database/sql"
- "fmt"
-)
-
-func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
- ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
-
- CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
- `)
- if err != nil {
- return fmt.Errorf("failed to execute upgrade: %w", err)
- }
- return nil
-}
-
-func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
- ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
-
- DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
- `)
- if err != nil {
- return fmt.Errorf("failed to execute downgrade: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
deleted file mode 100644
index 2aa11c52..00000000
--- a/keyserver/storage/postgres/device_keys_table.go
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "time"
-
- "github.com/lib/pq"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var deviceKeysSchema = `
--- Stores device keys for users
-CREATE TABLE IF NOT EXISTS keyserver_device_keys (
- user_id TEXT NOT NULL,
- device_id TEXT NOT NULL,
- ts_added_secs BIGINT NOT NULL,
- key_json TEXT NOT NULL,
- -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
- -- This means we do not store an unbounded append-only log of device keys, which is not actually
- -- required in the spec because in the event of a missed update the server fetches the entire
- -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
- stream_id BIGINT NOT NULL,
- display_name TEXT,
- -- Clobber based on tuple of user/device.
- CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
-);
-`
-
-const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
- " VALUES ($1, $2, $3, $4, $5, $6)" +
- " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
- " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
-
-const selectDeviceKeysSQL = "" +
- "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
-
-const selectBatchDeviceKeysWithEmptiesSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
-
-const selectMaxStreamForUserSQL = "" +
- "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
-
-const countStreamIDsForUserSQL = "" +
- "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
-
-const deleteDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const deleteAllDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1"
-
-type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- countStreamIDsForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
-}
-
-func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
- s := &deviceKeysStatements{
- db: db,
- }
- _, err := db.Exec(deviceKeysSchema)
- if err != nil {
- return nil, err
- }
- if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
- return nil, err
- }
- if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
- return nil, err
- }
- if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
- return nil, err
- }
- if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- for i, key := range keys {
- var keyJSONStr string
- var streamID int64
- var displayName sql.NullString
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
- if err != nil && err != sql.ErrNoRows {
- return err
- }
- // this will be '' when there is no device
- keys[i].Type = api.TypeDeviceKeyUpdate
- keys[i].KeyJSON = []byte(keyJSONStr)
- keys[i].StreamID = streamID
- if displayName.Valid {
- keys[i].DisplayName = displayName.String
- }
- }
- return nil
-}
-
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
- // nullable if there are no results
- var nullStream sql.NullInt64
- err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
- if err == sql.ErrNoRows {
- err = nil
- }
- if nullStream.Valid {
- streamID = nullStream.Int64
- }
- return
-}
-
-func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
- // nullable if there are no results
- var count sql.NullInt32
- err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
- if err != nil {
- return 0, err
- }
- if count.Valid {
- return int(count.Int32), nil
- }
- return 0, nil
-}
-
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
- for _, key := range keys {
- now := time.Now().Unix()
- _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
- )
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
- return err
-}
-
-func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
- return err
-}
-
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
- var stmt *sql.Stmt
- if includeEmpty {
- stmt = s.selectBatchDeviceKeysWithEmptiesStmt
- } else {
- stmt = s.selectBatchDeviceKeysStmt
- }
- rows, err := stmt.QueryContext(ctx, userID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
- deviceIDMap := make(map[string]bool)
- for _, d := range deviceIDs {
- deviceIDMap[d] = true
- }
- var result []api.DeviceMessage
- var displayName sql.NullString
- for rows.Next() {
- dk := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: userID,
- },
- }
- if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
- return nil, err
- }
- if displayName.Valid {
- dk.DisplayName = displayName.String
- }
- // include the key if we want all keys (no device) or it was asked
- if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
- result = append(result, dk)
- }
- }
- return result, rows.Err()
-}
diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go
deleted file mode 100644
index c0e3429c..00000000
--- a/keyserver/storage/postgres/key_changes_table.go
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "errors"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var keyChangesSchema = `
--- Stores key change information about users. Used to determine when to send updated device lists to clients.
-CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
-CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
- user_id TEXT NOT NULL,
- CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
-);
-`
-
-// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
-// have changed, hence we can just keep bumping the change ID for this user.
-const upsertKeyChangeSQL = "" +
- "INSERT INTO keyserver_key_changes (user_id)" +
- " VALUES ($1)" +
- " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
- " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
- " RETURNING change_id"
-
-const selectKeyChangesSQL = "" +
- "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
-
-type keyChangesStatements struct {
- db *sql.DB
- upsertKeyChangeStmt *sql.Stmt
- selectKeyChangesStmt *sql.Stmt
-}
-
-func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
- s := &keyChangesStatements{
- db: db,
- }
- _, err := db.Exec(keyChangesSchema)
- if err != nil {
- return s, err
- }
-
- if err = executeMigration(context.Background(), db); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func executeMigration(ctx context.Context, db *sql.DB) error {
- // TODO: Remove when we are sure we are not having goose artefacts in the db
- // This forces an error, which indicates the migration is already applied, since the
- // column partition was removed from the table
- migrationName := "keyserver: refactor key changes"
-
- var cName string
- err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
- if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
- return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
- }
- return nil
- }
- return err
- }
- m := sqlutil.NewMigrator(db)
- m.AddMigrations(sqlutil.Migration{
- Version: migrationName,
- Up: deltas.UpRefactorKeyChanges,
- })
-
- return m.Up(ctx)
-}
-
-func (s *keyChangesStatements) Prepare() (err error) {
- if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
- return err
- }
- if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
- return err
- }
- return nil
-}
-
-func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
- err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
- return
-}
-
-func (s *keyChangesStatements) SelectKeyChanges(
- ctx context.Context, fromOffset, toOffset int64,
-) (userIDs []string, latestOffset int64, err error) {
- latestOffset = fromOffset
- rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
- if err != nil {
- return nil, 0, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
- for rows.Next() {
- var userID string
- var offset int64
- if err := rows.Scan(&userID, &offset); err != nil {
- return nil, 0, err
- }
- if offset > latestOffset {
- latestOffset = offset
- }
- userIDs = append(userIDs, userID)
- }
- return
-}
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
deleted file mode 100644
index 2117efca..00000000
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ /dev/null
@@ -1,205 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "time"
-
- "github.com/lib/pq"
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var oneTimeKeysSchema = `
--- Stores one-time public keys for users
-CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
- user_id TEXT NOT NULL,
- device_id TEXT NOT NULL,
- key_id TEXT NOT NULL,
- algorithm TEXT NOT NULL,
- ts_added_secs BIGINT NOT NULL,
- key_json TEXT NOT NULL,
- -- Clobber based on 4-uple of user/device/key/algorithm.
- CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
-`
-
-const upsertKeysSQL = "" +
- "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4, $5, $6)" +
- " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
- " DO UPDATE SET key_json = $6"
-
-const selectKeysSQL = "" +
- "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
-
-const selectKeysCountSQL = "" +
- "SELECT algorithm, COUNT(key_id) FROM " +
- " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
- " x GROUP BY algorithm"
-
-const deleteOneTimeKeySQL = "" +
- "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
-
-const selectKeyByAlgorithmSQL = "" +
- "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
-
-const deleteOneTimeKeysSQL = "" +
- "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
-
-type oneTimeKeysStatements struct {
- db *sql.DB
- upsertKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
- selectKeysCountStmt *sql.Stmt
- selectKeyByAlgorithmStmt *sql.Stmt
- deleteOneTimeKeyStmt *sql.Stmt
- deleteOneTimeKeysStmt *sql.Stmt
-}
-
-func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
- s := &oneTimeKeysStatements{
- db: db,
- }
- _, err := db.Exec(oneTimeKeysSchema)
- if err != nil {
- return nil, err
- }
- if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
- return nil, err
- }
- if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
- rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
-
- result := make(map[string]json.RawMessage)
- var (
- algorithmWithID string
- keyJSONStr string
- )
- for rows.Next() {
- if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
- return nil, err
- }
- result[algorithmWithID] = json.RawMessage(keyJSONStr)
- }
- return result, rows.Err()
-}
-
-func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
- counts := &api.OneTimeKeysCount{
- DeviceID: deviceID,
- UserID: userID,
- KeyCount: make(map[string]int),
- }
- rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
- for rows.Next() {
- var algorithm string
- var count int
- if err = rows.Scan(&algorithm, &count); err != nil {
- return nil, err
- }
- counts.KeyCount[algorithm] = count
- }
- return counts, nil
-}
-
-func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
- now := time.Now().Unix()
- counts := &api.OneTimeKeysCount{
- DeviceID: keys.DeviceID,
- UserID: keys.UserID,
- KeyCount: make(map[string]int),
- }
- for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
- algo, keyID := keys.Split(keyIDWithAlgo)
- _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
- ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
- )
- if err != nil {
- return nil, err
- }
- }
- rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
- for rows.Next() {
- var algorithm string
- var count int
- if err = rows.Scan(&algorithm, &count); err != nil {
- return nil, err
- }
- counts.KeyCount[algorithm] = count
- }
-
- return counts, rows.Err()
-}
-
-func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
- ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
-) (map[string]json.RawMessage, error) {
- var keyID string
- var keyJSON string
- err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil
- }
- return nil, err
- }
- _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
- return map[string]json.RawMessage{
- algorithm + ":" + keyID: json.RawMessage(keyJSON),
- }, err
-}
-
-func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
- return err
-}
diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go
deleted file mode 100644
index 248ddfb4..00000000
--- a/keyserver/storage/postgres/stale_device_lists.go
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "time"
-
- "github.com/lib/pq"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var staleDeviceListsSchema = `
--- Stores whether a user's device lists are stale or not.
-CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
- user_id TEXT PRIMARY KEY NOT NULL,
- domain TEXT NOT NULL,
- is_stale BOOLEAN NOT NULL,
- ts_added_secs BIGINT NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
-`
-
-const upsertStaleDeviceListSQL = "" +
- "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
- " VALUES ($1, $2, $3, $4)" +
- " ON CONFLICT (user_id)" +
- " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
-
-const selectStaleDeviceListsWithDomainsSQL = "" +
- "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
-
-const selectStaleDeviceListsSQL = "" +
- "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
-
-const deleteStaleDevicesSQL = "" +
- "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
-
-type staleDeviceListsStatements struct {
- upsertStaleDeviceListStmt *sql.Stmt
- selectStaleDeviceListsWithDomainsStmt *sql.Stmt
- selectStaleDeviceListsStmt *sql.Stmt
- deleteStaleDeviceListsStmt *sql.Stmt
-}
-
-func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
- s := &staleDeviceListsStatements{}
- _, err := db.Exec(staleDeviceListsSchema)
- if err != nil {
- return nil, err
- }
- return s, sqlutil.StatementList{
- {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
- {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
- {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
- {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
- }.Prepare(db)
-}
-
-func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
- _, domain, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- return err
- }
- _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
- return err
-}
-
-func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- // we only query for 1 domain or all domains so optimise for those use cases
- if len(domains) == 0 {
- rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
- if err != nil {
- return nil, err
- }
- return rowsToUserIDs(ctx, rows)
- }
- var result []string
- for _, domain := range domains {
- rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
- if err != nil {
- return nil, err
- }
- userIDs, err := rowsToUserIDs(ctx, rows)
- if err != nil {
- return nil, err
- }
- result = append(result, userIDs...)
- }
- return result, nil
-}
-
-// DeleteStaleDeviceLists removes users from stale device lists
-func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
- ctx context.Context, txn *sql.Tx, userIDs []string,
-) error {
- stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
- _, err := stmt.ExecContext(ctx, pq.Array(userIDs))
- return err
-}
-
-func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
- defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
- for rows.Next() {
- var userID string
- if err := rows.Scan(&userID); err != nil {
- return nil, err
- }
- result = append(result, userID)
- }
- return result, rows.Err()
-}
diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go
deleted file mode 100644
index 35e63055..00000000
--- a/keyserver/storage/postgres/storage.go
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/shared"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-// NewDatabase creates a new sync server database
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
- var err error
- db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
- if err != nil {
- return nil, err
- }
- otk, err := NewPostgresOneTimeKeysTable(db)
- if err != nil {
- return nil, err
- }
- dk, err := NewPostgresDeviceKeysTable(db)
- if err != nil {
- return nil, err
- }
- kc, err := NewPostgresKeyChangesTable(db)
- if err != nil {
- return nil, err
- }
- sdl, err := NewPostgresStaleDeviceListsTable(db)
- if err != nil {
- return nil, err
- }
- csk, err := NewPostgresCrossSigningKeysTable(db)
- if err != nil {
- return nil, err
- }
- css, err := NewPostgresCrossSigningSigsTable(db)
- if err != nil {
- return nil, err
- }
- if err = kc.Prepare(); err != nil {
- return nil, err
- }
- d := &shared.Database{
- DB: db,
- Writer: writer,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
- StaleDeviceListsTable: sdl,
- CrossSigningKeysTable: csk,
- CrossSigningSigsTable: css,
- }
- return d, nil
-}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
deleted file mode 100644
index 54dd6ddc..00000000
--- a/keyserver/storage/shared/storage.go
+++ /dev/null
@@ -1,261 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package shared
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type Database struct {
- DB *sql.DB
- Writer sqlutil.Writer
- OneTimeKeysTable tables.OneTimeKeys
- DeviceKeysTable tables.DeviceKeys
- KeyChangesTable tables.KeyChanges
- StaleDeviceListsTable tables.StaleDeviceLists
- CrossSigningKeysTable tables.CrossSigningKeys
- CrossSigningSigsTable tables.CrossSigningSigs
-}
-
-func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
- return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
-}
-
-func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
- _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
- return err
- })
- return
-}
-
-func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
- return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
-}
-
-func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
-}
-
-func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
- count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
- if err != nil {
- return false, err
- }
- return count == len(prevIDs), nil
-}
-
-func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for _, userID := range clearUserIDs {
- err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
- if err != nil {
- return err
- }
- }
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
- })
-}
-
-func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
- // work out the latest stream IDs for each user
- userIDToStreamID := make(map[string]int64)
- for _, k := range keys {
- userIDToStreamID[k.UserID] = 0
- }
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for userID := range userIDToStreamID {
- streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
- if err != nil {
- return err
- }
- userIDToStreamID[userID] = streamID
- }
- // set the stream IDs for each key
- for i := range keys {
- k := keys[i]
- userIDToStreamID[k.UserID]++ // start stream from 1
- k.StreamID = userIDToStreamID[k.UserID]
- keys[i] = k
- }
- return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
- })
-}
-
-func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
- return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
-}
-
-func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
- var result []api.OneTimeKeys
- err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for userID, deviceToAlgo := range userToDeviceToAlgorithm {
- for deviceID, algo := range deviceToAlgo {
- keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
- if err != nil {
- return err
- }
- if keyJSON != nil {
- result = append(result, api.OneTimeKeys{
- UserID: userID,
- DeviceID: deviceID,
- KeyJSON: keyJSON,
- })
- }
- }
- }
- return nil
- })
- return result, err
-}
-
-func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
- err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
- id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
- return err
- })
- return
-}
-
-func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
- return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
-}
-
-// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
-// If no domains are given, all user IDs with stale device lists are returned.
-func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
-}
-
-// MarkDeviceListStale sets the stale bit for this user to isStale.
-func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
- return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
- return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
- })
-}
-
-// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
-// cross-signing signatures relating to that device.
-func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for _, deviceID := range deviceIDs {
- if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
- }
- if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
- }
- if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
- return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
- }
- }
- return nil
- })
-}
-
-// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
-func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
- keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
- if err != nil {
- return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
- }
- results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
- for purpose, key := range keyMap {
- keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
- result := gomatrixserverlib.CrossSigningKey{
- UserID: userID,
- Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
- Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
- keyID: key,
- },
- }
- sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
- if err != nil {
- continue
- }
- for sigUserID, forSigUserID := range sigMap {
- if userID != sigUserID {
- continue
- }
- if result.Signatures == nil {
- result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- if _, ok := result.Signatures[sigUserID]; !ok {
- result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- for sigKeyID, sigBytes := range forSigUserID {
- result.Signatures[sigUserID][sigKeyID] = sigBytes
- }
- }
- results[purpose] = result
- }
- return results, nil
-}
-
-// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
-func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
- return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
-}
-
-// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
-func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
- return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
-}
-
-// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
-func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- for keyType, keyData := range keyMap {
- if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
- return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
- }
- }
- return nil
- })
-}
-
-// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
-func (d *Database) StoreCrossSigningSigsForTarget(
- ctx context.Context,
- originUserID string, originKeyID gomatrixserverlib.KeyID,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
- signature gomatrixserverlib.Base64Bytes,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
- return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
- }
- return nil
- })
-}
-
-// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
-func (d *Database) DeleteStaleDeviceLists(
- ctx context.Context,
- userIDs []string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
- })
-}
diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go
deleted file mode 100644
index e103d988..00000000
--- a/keyserver/storage/sqlite3/cross_signing_keys_table.go
+++ /dev/null
@@ -1,101 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var crossSigningKeysSchema = `
-CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
- user_id TEXT NOT NULL,
- key_type INTEGER NOT NULL,
- key_data TEXT NOT NULL,
- PRIMARY KEY (user_id, key_type)
-);
-`
-
-const selectCrossSigningKeysForUserSQL = "" +
- "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
- " WHERE user_id = $1"
-
-const upsertCrossSigningKeysForUserSQL = "" +
- "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
- " VALUES($1, $2, $3)"
-
-type crossSigningKeysStatements struct {
- db *sql.DB
- selectCrossSigningKeysForUserStmt *sql.Stmt
- upsertCrossSigningKeysForUserStmt *sql.Stmt
-}
-
-func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
- s := &crossSigningKeysStatements{
- db: db,
- }
- _, err := db.Exec(crossSigningKeysSchema)
- if err != nil {
- return nil, err
- }
- return s, sqlutil.StatementList{
- {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
- {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
- }.Prepare(db)
-}
-
-func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
- ctx context.Context, txn *sql.Tx, userID string,
-) (r types.CrossSigningKeyMap, err error) {
- rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
- r = types.CrossSigningKeyMap{}
- for rows.Next() {
- var keyTypeInt int16
- var keyData gomatrixserverlib.Base64Bytes
- if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
- return nil, err
- }
- keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
- if !ok {
- return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
- }
- r[keyType] = keyData
- }
- return
-}
-
-func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
- ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
-) error {
- keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
- if !ok {
- return fmt.Errorf("unknown key purpose %q", keyType)
- }
- if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
- return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go
deleted file mode 100644
index 7a153e8f..00000000
--- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var crossSigningSigsSchema = `
-CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
- origin_user_id TEXT NOT NULL,
- origin_key_id TEXT NOT NULL,
- target_user_id TEXT NOT NULL,
- target_key_id TEXT NOT NULL,
- signature TEXT NOT NULL,
- PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
-`
-
-const selectCrossSigningSigsForTargetSQL = "" +
- "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
- " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
-
-const upsertCrossSigningSigsForTargetSQL = "" +
- "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
- " VALUES($1, $2, $3, $4, $5)"
-
-const deleteCrossSigningSigsForTargetSQL = "" +
- "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
-
-type crossSigningSigsStatements struct {
- db *sql.DB
- selectCrossSigningSigsForTargetStmt *sql.Stmt
- upsertCrossSigningSigsForTargetStmt *sql.Stmt
- deleteCrossSigningSigsForTargetStmt *sql.Stmt
-}
-
-func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
- s := &crossSigningSigsStatements{
- db: db,
- }
- _, err := db.Exec(crossSigningSigsSchema)
- if err != nil {
- return nil, err
- }
- m := sqlutil.NewMigrator(db)
- m.AddMigrations(sqlutil.Migration{
- Version: "keyserver: cross signing signature indexes",
- Up: deltas.UpFixCrossSigningSignatureIndexes,
- })
- if err = m.Up(context.Background()); err != nil {
- return nil, err
- }
-
- return s, sqlutil.StatementList{
- {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
- {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
- {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
- }.Prepare(db)
-}
-
-func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
-) (r types.CrossSigningSigMap, err error) {
- rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
- r = types.CrossSigningSigMap{}
- for rows.Next() {
- var userID string
- var keyID gomatrixserverlib.KeyID
- var signature gomatrixserverlib.Base64Bytes
- if err := rows.Scan(&userID, &keyID, &signature); err != nil {
- return nil, err
- }
- if _, ok := r[userID]; !ok {
- r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
- }
- r[userID][keyID] = signature
- }
- return
-}
-
-func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx,
- originUserID string, originKeyID gomatrixserverlib.KeyID,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
- signature gomatrixserverlib.Base64Bytes,
-) error {
- if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
- return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
- }
- return nil
-}
-
-func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
- ctx context.Context, txn *sql.Tx,
- targetUserID string, targetKeyID gomatrixserverlib.KeyID,
-) error {
- if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
- return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
deleted file mode 100644
index cd0f19df..00000000
--- a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package deltas
-
-import (
- "context"
- "database/sql"
- "fmt"
-)
-
-func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
- // start counting from the last max offset, else 0.
- var maxOffset int64
- var userID string
- _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
-
- _, err := tx.ExecContext(ctx, `
- -- make the new table
- DROP TABLE IF EXISTS keyserver_key_changes;
- CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- change_id INTEGER PRIMARY KEY AUTOINCREMENT,
- -- The key owner
- user_id TEXT NOT NULL,
- UNIQUE (user_id)
- );
- `)
- if err != nil {
- return fmt.Errorf("failed to execute upgrade: %w", err)
- }
- // to start counting from maxOffset, insert a row with that value
- if userID != "" {
- _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
- return err
- }
- return nil
-}
-
-func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
- DROP TABLE IF EXISTS keyserver_key_changes;
- CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- partition BIGINT NOT NULL,
- offset BIGINT NOT NULL,
- -- The key owner
- user_id TEXT NOT NULL,
- UNIQUE (partition, offset)
- );
- `)
- if err != nil {
- return fmt.Errorf("failed to execute downgrade: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
deleted file mode 100644
index d4e38dea..00000000
--- a/keyserver/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2022 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package deltas
-
-import (
- "context"
- "database/sql"
- "fmt"
-)
-
-func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
- origin_user_id TEXT NOT NULL,
- origin_key_id TEXT NOT NULL,
- target_user_id TEXT NOT NULL,
- target_key_id TEXT NOT NULL,
- signature TEXT NOT NULL,
- PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
- );
-
- INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
- SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
-
- DROP TABLE keyserver_cross_signing_sigs;
- ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
-
- CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
- `)
- if err != nil {
- return fmt.Errorf("failed to execute upgrade: %w", err)
- }
- return nil
-}
-
-func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
- _, err := tx.ExecContext(ctx, `
- CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
- origin_user_id TEXT NOT NULL,
- origin_key_id TEXT NOT NULL,
- target_user_id TEXT NOT NULL,
- target_key_id TEXT NOT NULL,
- signature TEXT NOT NULL,
- PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
- );
-
- INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
- SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
-
- DROP TABLE keyserver_cross_signing_sigs;
- ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
-
- DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
- `)
- if err != nil {
- return fmt.Errorf("failed to execute downgrade: %w", err)
- }
- return nil
-}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
deleted file mode 100644
index 73768da5..00000000
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ /dev/null
@@ -1,225 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "strings"
- "time"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var deviceKeysSchema = `
--- Stores device keys for users
-CREATE TABLE IF NOT EXISTS keyserver_device_keys (
- user_id TEXT NOT NULL,
- device_id TEXT NOT NULL,
- ts_added_secs BIGINT NOT NULL,
- key_json TEXT NOT NULL,
- stream_id BIGINT NOT NULL,
- display_name TEXT,
- -- Clobber based on tuple of user/device.
- UNIQUE (user_id, device_id)
-);
-`
-
-const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
- " VALUES ($1, $2, $3, $4, $5, $6)" +
- " ON CONFLICT (user_id, device_id)" +
- " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
-
-const selectDeviceKeysSQL = "" +
- "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
-
-const selectBatchDeviceKeysWithEmptiesSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
-
-const selectMaxStreamForUserSQL = "" +
- "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
-
-const countStreamIDsForUserSQL = "" +
- "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
-
-const deleteDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const deleteAllDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1"
-
-type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
-}
-
-func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
- s := &deviceKeysStatements{
- db: db,
- }
- _, err := db.Exec(deviceKeysSchema)
- if err != nil {
- return nil, err
- }
- if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
- return nil, err
- }
- if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
- return nil, err
- }
- if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
- return err
-}
-
-func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
- return err
-}
-
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
- deviceIDMap := make(map[string]bool)
- for _, d := range deviceIDs {
- deviceIDMap[d] = true
- }
- var stmt *sql.Stmt
- if includeEmpty {
- stmt = s.selectBatchDeviceKeysWithEmptiesStmt
- } else {
- stmt = s.selectBatchDeviceKeysStmt
- }
- rows, err := stmt.QueryContext(ctx, userID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
- var result []api.DeviceMessage
- var displayName sql.NullString
- for rows.Next() {
- dk := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: userID,
- },
- }
- if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
- return nil, err
- }
- if displayName.Valid {
- dk.DisplayName = displayName.String
- }
- // include the key if we want all keys (no device) or it was asked
- if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
- result = append(result, dk)
- }
- }
- return result, rows.Err()
-}
-
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- for i, key := range keys {
- var keyJSONStr string
- var streamID int64
- var displayName sql.NullString
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
- if err != nil && err != sql.ErrNoRows {
- return err
- }
- // this will be '' when there is no device
- keys[i].Type = api.TypeDeviceKeyUpdate
- keys[i].KeyJSON = []byte(keyJSONStr)
- keys[i].StreamID = streamID
- if displayName.Valid {
- keys[i].DisplayName = displayName.String
- }
- }
- return nil
-}
-
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
- // nullable if there are no results
- var nullStream sql.NullInt64
- err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
- if err == sql.ErrNoRows {
- err = nil
- }
- if nullStream.Valid {
- streamID = nullStream.Int64
- }
- return
-}
-
-func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
- iStreamIDs := make([]interface{}, len(streamIDs)+1)
- iStreamIDs[0] = userID
- for i := range streamIDs {
- iStreamIDs[i+1] = streamIDs[i]
- }
- query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
- // nullable if there are no results
- var count sql.NullInt64
- err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
- if err != nil {
- return 0, err
- }
- if count.Valid {
- return int(count.Int64), nil
- }
- return 0, nil
-}
-
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
- for _, key := range keys {
- now := time.Now().Unix()
- _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
- )
- if err != nil {
- return err
- }
- }
- return nil
-}
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
deleted file mode 100644
index 0c844d67..00000000
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "errors"
- "fmt"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var keyChangesSchema = `
--- Stores key change information about users. Used to determine when to send updated device lists to clients.
-CREATE TABLE IF NOT EXISTS keyserver_key_changes (
- change_id INTEGER PRIMARY KEY AUTOINCREMENT,
- -- The key owner
- user_id TEXT NOT NULL,
- UNIQUE (user_id)
-);
-`
-
-// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
-// have changed, hence we can just keep bumping the change ID for this user.
-const upsertKeyChangeSQL = "" +
- "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
- " VALUES ($1)" +
- " RETURNING change_id"
-
-const selectKeyChangesSQL = "" +
- "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
-
-type keyChangesStatements struct {
- db *sql.DB
- upsertKeyChangeStmt *sql.Stmt
- selectKeyChangesStmt *sql.Stmt
-}
-
-func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
- s := &keyChangesStatements{
- db: db,
- }
- _, err := db.Exec(keyChangesSchema)
- if err != nil {
- return s, err
- }
-
- if err = executeMigration(context.Background(), db); err != nil {
- return nil, err
- }
-
- return s, nil
-}
-
-func executeMigration(ctx context.Context, db *sql.DB) error {
- // TODO: Remove when we are sure we are not having goose artefacts in the db
- // This forces an error, which indicates the migration is already applied, since the
- // column partition was removed from the table
- migrationName := "keyserver: refactor key changes"
-
- var cName string
- err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
- if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
- return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
- }
- return nil
- }
- return err
- }
- m := sqlutil.NewMigrator(db)
- m.AddMigrations(sqlutil.Migration{
- Version: migrationName,
- Up: deltas.UpRefactorKeyChanges,
- })
- return m.Up(ctx)
-}
-
-func (s *keyChangesStatements) Prepare() (err error) {
- if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
- return err
- }
- if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
- return err
- }
- return nil
-}
-
-func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
- err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
- return
-}
-
-func (s *keyChangesStatements) SelectKeyChanges(
- ctx context.Context, fromOffset, toOffset int64,
-) (userIDs []string, latestOffset int64, err error) {
- latestOffset = fromOffset
- rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
- if err != nil {
- return nil, 0, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
- for rows.Next() {
- var userID string
- var offset int64
- if err := rows.Scan(&userID, &offset); err != nil {
- return nil, 0, err
- }
- if offset > latestOffset {
- latestOffset = offset
- }
- userIDs = append(userIDs, userID)
- }
- return
-}
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
deleted file mode 100644
index 7a923d0e..00000000
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ /dev/null
@@ -1,219 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "time"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var oneTimeKeysSchema = `
--- Stores one-time public keys for users
-CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
- user_id TEXT NOT NULL,
- device_id TEXT NOT NULL,
- key_id TEXT NOT NULL,
- algorithm TEXT NOT NULL,
- ts_added_secs BIGINT NOT NULL,
- key_json TEXT NOT NULL,
- -- Clobber based on 4-uple of user/device/key/algorithm.
- UNIQUE (user_id, device_id, key_id, algorithm)
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
-`
-
-const upsertKeysSQL = "" +
- "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
- " VALUES ($1, $2, $3, $4, $5, $6)" +
- " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
- " DO UPDATE SET key_json = $6"
-
-const selectKeysSQL = "" +
- "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
-
-const selectKeysCountSQL = "" +
- "SELECT algorithm, COUNT(key_id) FROM " +
- " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
- " x GROUP BY algorithm"
-
-const deleteOneTimeKeySQL = "" +
- "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
-
-const selectKeyByAlgorithmSQL = "" +
- "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
-
-const deleteOneTimeKeysSQL = "" +
- "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
-
-type oneTimeKeysStatements struct {
- db *sql.DB
- upsertKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
- selectKeysCountStmt *sql.Stmt
- selectKeyByAlgorithmStmt *sql.Stmt
- deleteOneTimeKeyStmt *sql.Stmt
- deleteOneTimeKeysStmt *sql.Stmt
-}
-
-func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
- s := &oneTimeKeysStatements{
- db: db,
- }
- _, err := db.Exec(oneTimeKeysSchema)
- if err != nil {
- return nil, err
- }
- if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
- return nil, err
- }
- if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
- return nil, err
- }
- if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
- return nil, err
- }
- if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
- rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
-
- wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
- for _, ka := range keyIDsWithAlgorithms {
- wantSet[ka] = true
- }
-
- result := make(map[string]json.RawMessage)
- for rows.Next() {
- var keyID string
- var algorithm string
- var keyJSONStr string
- if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
- return nil, err
- }
- keyIDWithAlgo := algorithm + ":" + keyID
- if wantSet[keyIDWithAlgo] {
- result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
- }
- }
- return result, rows.Err()
-}
-
-func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
- counts := &api.OneTimeKeysCount{
- DeviceID: deviceID,
- UserID: userID,
- KeyCount: make(map[string]int),
- }
- rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
- for rows.Next() {
- var algorithm string
- var count int
- if err = rows.Scan(&algorithm, &count); err != nil {
- return nil, err
- }
- counts.KeyCount[algorithm] = count
- }
- return counts, nil
-}
-
-func (s *oneTimeKeysStatements) InsertOneTimeKeys(
- ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
-) (*api.OneTimeKeysCount, error) {
- now := time.Now().Unix()
- counts := &api.OneTimeKeysCount{
- DeviceID: keys.DeviceID,
- UserID: keys.UserID,
- KeyCount: make(map[string]int),
- }
- for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
- algo, keyID := keys.Split(keyIDWithAlgo)
- _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
- ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
- )
- if err != nil {
- return nil, err
- }
- }
- rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
- for rows.Next() {
- var algorithm string
- var count int
- if err = rows.Scan(&algorithm, &count); err != nil {
- return nil, err
- }
- counts.KeyCount[algorithm] = count
- }
-
- return counts, rows.Err()
-}
-
-func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
- ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
-) (map[string]json.RawMessage, error) {
- var keyID string
- var keyJSON string
- err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil
- }
- return nil, err
- }
- _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
- if err != nil {
- return nil, err
- }
- if keyJSON == "" {
- return nil, nil
- }
- return map[string]json.RawMessage{
- algorithm + ":" + keyID: json.RawMessage(keyJSON),
- }, err
-}
-
-func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
- return err
-}
diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go
deleted file mode 100644
index fd76a6e3..00000000
--- a/keyserver/storage/sqlite3/stale_device_lists.go
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "database/sql"
- "strings"
- "time"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-var staleDeviceListsSchema = `
--- Stores whether a user's device lists are stale or not.
-CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
- user_id TEXT PRIMARY KEY NOT NULL,
- domain TEXT NOT NULL,
- is_stale BOOLEAN NOT NULL,
- ts_added_secs BIGINT NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
-`
-
-const upsertStaleDeviceListSQL = "" +
- "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
- " VALUES ($1, $2, $3, $4)" +
- " ON CONFLICT (user_id)" +
- " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
-
-const selectStaleDeviceListsWithDomainsSQL = "" +
- "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
-
-const selectStaleDeviceListsSQL = "" +
- "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
-
-const deleteStaleDevicesSQL = "" +
- "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
-
-type staleDeviceListsStatements struct {
- db *sql.DB
- upsertStaleDeviceListStmt *sql.Stmt
- selectStaleDeviceListsWithDomainsStmt *sql.Stmt
- selectStaleDeviceListsStmt *sql.Stmt
- // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
-}
-
-func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
- s := &staleDeviceListsStatements{
- db: db,
- }
- _, err := db.Exec(staleDeviceListsSchema)
- if err != nil {
- return nil, err
- }
- return s, sqlutil.StatementList{
- {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
- {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
- {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
- // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
- }.Prepare(db)
-}
-
-func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
- _, domain, err := gomatrixserverlib.SplitID('@', userID)
- if err != nil {
- return err
- }
- _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
- return err
-}
-
-func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- // we only query for 1 domain or all domains so optimise for those use cases
- if len(domains) == 0 {
- rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
- if err != nil {
- return nil, err
- }
- return rowsToUserIDs(ctx, rows)
- }
- var result []string
- for _, domain := range domains {
- rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
- if err != nil {
- return nil, err
- }
- userIDs, err := rowsToUserIDs(ctx, rows)
- if err != nil {
- return nil, err
- }
- result = append(result, userIDs...)
- }
- return result, nil
-}
-
-// DeleteStaleDeviceLists removes users from stale device lists
-func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
- ctx context.Context, txn *sql.Tx, userIDs []string,
-) error {
- qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
- stmt, err := s.db.Prepare(qry)
- if err != nil {
- return err
- }
- defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
- stmt = sqlutil.TxStmt(txn, stmt)
-
- params := make([]any, len(userIDs))
- for i := range userIDs {
- params[i] = userIDs[i]
- }
-
- _, err = stmt.ExecContext(ctx, params...)
- return err
-}
-
-func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
- defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
- for rows.Next() {
- var userID string
- if err := rows.Scan(&userID); err != nil {
- return nil, err
- }
- result = append(result, userID)
- }
- return result, rows.Err()
-}
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
deleted file mode 100644
index 873fe3e2..00000000
--- a/keyserver/storage/sqlite3/storage.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/shared"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
- db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
- if err != nil {
- return nil, err
- }
- otk, err := NewSqliteOneTimeKeysTable(db)
- if err != nil {
- return nil, err
- }
- dk, err := NewSqliteDeviceKeysTable(db)
- if err != nil {
- return nil, err
- }
- kc, err := NewSqliteKeyChangesTable(db)
- if err != nil {
- return nil, err
- }
- sdl, err := NewSqliteStaleDeviceListsTable(db)
- if err != nil {
- return nil, err
- }
- csk, err := NewSqliteCrossSigningKeysTable(db)
- if err != nil {
- return nil, err
- }
- css, err := NewSqliteCrossSigningSigsTable(db)
- if err != nil {
- return nil, err
- }
-
- if err = kc.Prepare(); err != nil {
- return nil, err
- }
- d := &shared.Database{
- DB: db,
- Writer: writer,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
- StaleDeviceListsTable: sdl,
- CrossSigningKeysTable: csk,
- CrossSigningSigsTable: css,
- }
- return d, nil
-}
diff --git a/keyserver/storage/storage.go b/keyserver/storage/storage.go
deleted file mode 100644
index ab6a3540..00000000
--- a/keyserver/storage/storage.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build !wasm
-// +build !wasm
-
-package storage
-
-import (
- "fmt"
-
- "github.com/matrix-org/dendrite/keyserver/storage/postgres"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
-// and sets postgres connection parameters
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties)
- case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(base, dbProperties)
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
deleted file mode 100644
index e7a2af7c..00000000
--- a/keyserver/storage/storage_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package storage_test
-
-import (
- "context"
- "reflect"
- "sync"
- "testing"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-var ctx = context.Background()
-
-func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
- base, close := testrig.CreateBaseDendrite(t, dbType)
- db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
- if err != nil {
- t.Fatalf("failed to create new database: %v", err)
- }
- return db, close
-}
-
-func MustNotError(t *testing.T, err error) {
- t.Helper()
- if err == nil {
- return
- }
- t.Fatalf("operation failed: %s", err)
-}
-
-func TestKeyChanges(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- _, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDC {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
- }
- if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesNoDupes(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- if deviceChangeIDA == deviceChangeIDB {
- t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
- }
- deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeID {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
- }
- if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesUpperLimit(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDB {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
- }
- if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-var dbLock sync.Mutex
-var deviceArray = []string{"AAA", "another_device"}
-
-// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
-// and that they are returned correctly when querying for device keys.
-func TestDeviceKeysStreamIDGeneration(t *testing.T) {
- var err error
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- alice := "@alice:TestDeviceKeysStreamIDGeneration"
- bob := "@bob:TestDeviceKeysStreamIDGeneration"
- msgs := []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: bob,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1 as this is a different user
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "another_device",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 2 as this is a 2nd device key
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
- }
- if msgs[1].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
- }
- if msgs[2].StreamID != 2 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
- }
-
- // updating a device sets the next stream ID for that user
- msgs = []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v2"}`),
- },
- // StreamID: 3
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 3 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
- }
-
- dbLock.Lock()
- defer dbLock.Unlock()
- // Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
-
- if err != nil {
- t.Fatalf("DeviceKeysForUser returned error: %s", err)
- }
- wantStreamIDs := map[string]int64{
- "AAA": 3,
- "another_device": 2,
- }
- if len(msgs) != len(wantStreamIDs) {
- t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
- }
- for _, m := range msgs {
- if m.StreamID != wantStreamIDs[m.DeviceID] {
- t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
- }
- }
- })
-}
diff --git a/keyserver/storage/storage_wasm.go b/keyserver/storage/storage_wasm.go
deleted file mode 100644
index 75c9053e..00000000
--- a/keyserver/storage/storage_wasm.go
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package storage
-
-import (
- "fmt"
-
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
- "github.com/matrix-org/dendrite/setup/base"
- "github.com/matrix-org/dendrite/setup/config"
-)
-
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties)
- case dbProperties.ConnectionString.IsPostgres():
- return nil, fmt.Errorf("can't use Postgres implementation")
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
deleted file mode 100644
index 24da1125..00000000
--- a/keyserver/storage/tables/interface.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tables
-
-import (
- "context"
- "database/sql"
- "encoding/json"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type OneTimeKeys interface {
- SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
- CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
- InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
- // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
- // Returns an empty map if the key does not exist.
- SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
- DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
-}
-
-type DeviceKeys interface {
- SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
- InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
- SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
- CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
- SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
- DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
- DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
-}
-
-type KeyChanges interface {
- InsertKeyChange(ctx context.Context, userID string) (int64, error)
- // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
- // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
- SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
-
- Prepare() error
-}
-
-type StaleDeviceLists interface {
- InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
- SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
- DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
-}
-
-type CrossSigningKeys interface {
- SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
- UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
-}
-
-type CrossSigningSigs interface {
- SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
- UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
- DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
-}
diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go
deleted file mode 100644
index 76d3badd..00000000
--- a/keyserver/storage/tables/stale_device_lists_test.go
+++ /dev/null
@@ -1,94 +0,0 @@
-package tables_test
-
-import (
- "context"
- "testing"
-
- "github.com/matrix-org/gomatrixserverlib"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
- "github.com/matrix-org/dendrite/setup/config"
-
- "github.com/matrix-org/dendrite/keyserver/storage/postgres"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
- "github.com/matrix-org/dendrite/test"
-)
-
-func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
- connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := sqlutil.Open(&config.DatabaseOptions{
- ConnectionString: config.DataSource(connStr),
- }, nil)
- if err != nil {
- t.Fatalf("failed to open database: %s", err)
- }
- switch dbType {
- case test.DBTypePostgres:
- tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
- case test.DBTypeSQLite:
- tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
- }
- if err != nil {
- t.Fatalf("failed to create new table: %s", err)
- }
- return tab, close
-}
-
-func TestStaleDeviceLists(t *testing.T) {
- alice := test.NewUser(t)
- bob := test.NewUser(t)
- charlie := "@charlie:localhost"
- ctx := context.Background()
-
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- tab, closeDB := mustCreateTable(t, dbType)
- defer closeDB()
-
- if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
- t.Fatalf("failed to insert stale device: %s", err)
- }
- if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
- t.Fatalf("failed to insert stale device: %s", err)
- }
- if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
- t.Fatalf("failed to insert stale device: %s", err)
- }
-
- // Query one server
- wantStaleUsers := []string{alice.ID, bob.ID}
- gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
- if err != nil {
- t.Fatalf("failed to query stale device lists: %s", err)
- }
- if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
- t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
- }
-
- // Query all servers
- wantStaleUsers = []string{alice.ID, bob.ID, charlie}
- gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
- if err != nil {
- t.Fatalf("failed to query stale device lists: %s", err)
- }
- if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
- t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
- }
-
- // Delete stale devices
- deleteUsers := []string{alice.ID, bob.ID}
- if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
- t.Fatalf("failed to delete stale device lists: %s", err)
- }
-
- // Verify we don't get anything back after deleting
- gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
- if err != nil {
- t.Fatalf("failed to query stale device lists: %s", err)
- }
-
- if gotCount := len(gotStaleUsers); gotCount > 0 {
- t.Fatalf("expected no stale users, got %d", gotCount)
- }
- })
-}
diff --git a/keyserver/types/storage.go b/keyserver/types/storage.go
deleted file mode 100644
index 7fb90454..00000000
--- a/keyserver/types/storage.go
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2021 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package types
-
-import (
- "math"
-
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-const (
- // OffsetNewest tells e.g. the database to get the most current data
- OffsetNewest int64 = math.MaxInt64
- // OffsetOldest tells e.g. the database to get the oldest data
- OffsetOldest int64 = 0
-)
-
-// KeyTypePurposeToInt maps a purpose to an integer, which is used in the
-// database to reduce the amount of space taken up by this column.
-var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{
- gomatrixserverlib.CrossSigningKeyPurposeMaster: 1,
- gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2,
- gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3,
-}
-
-// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the
-// database to reduce the amount of space taken up by this column.
-var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{
- 1: gomatrixserverlib.CrossSigningKeyPurposeMaster,
- 2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
- 3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
-}
-
-// Map of purpose -> public key
-type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes
-
-// Map of user ID -> key ID -> signature
-type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes