aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-08-07 17:32:13 +0100
committerGitHub <noreply@github.com>2020-08-07 17:32:13 +0100
commitf371783da765f96fc3764091e95fb8cb8004e208 (patch)
treedd1748eca0719054be508f2f026cb0e314d2565f /keyserver
parent30c2325eaf85f28f438f9a3c7b703978eee66cf7 (diff)
Finish inbound E2E device lists (#1243)
* Add tests for device list updates * Add stale_device_lists table and use db before asking remote for device keys * Fetch remote keys if all devices are requested * Add display_name col to store remote device names Few other tweaks to make `Server correctly handles incoming m.device_list_update` pass. * Fix sqlite otk bug * Unbuffered channel to block /send causing sytest to not race anymore * Linting and fix bug whereby we didn't send updated dl tokens to the client causing a tightloop on /sync sometimes * No longer assert staleness as Update blocks on workers now * Back out tweaks * Bugfixes
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/internal/device_list_update.go34
-rw-r--r--keyserver/internal/device_list_update_test.go242
-rw-r--r--keyserver/internal/internal.go47
-rw-r--r--keyserver/storage/postgres/device_keys_table.go25
-rw-r--r--keyserver/storage/postgres/stale_device_lists.go118
-rw-r--r--keyserver/storage/postgres/storage.go13
-rw-r--r--keyserver/storage/shared/storage.go13
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go25
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go3
-rw-r--r--keyserver/storage/sqlite3/stale_device_lists.go118
-rw-r--r--keyserver/storage/sqlite3/storage.go13
-rw-r--r--keyserver/storage/tables/interface.go6
12 files changed, 616 insertions, 41 deletions
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index 19d8463d..ec7dff56 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -23,7 +23,6 @@ import (
"time"
"github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -65,7 +64,7 @@ type DeviceListUpdater struct {
mu *sync.Mutex // protects UserIDToMutex
db DeviceListUpdaterDatabase
- producer *producers.KeyChange
+ producer KeyChangeProducer
fedClient *gomatrixserverlib.FederationClient
workerChans []chan gomatrixserverlib.ServerName
}
@@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface {
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
}
+// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
+type KeyChangeProducer interface {
+ ProduceKeyChanges(keys []api.DeviceMessage) error
+}
+
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
- db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
+ db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient,
numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
@@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
if err != nil {
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
}
+ // if this is the first time we're hearing about this user, sync the device list manually.
+ if len(event.PrevID) == 0 {
+ exists = false
+ }
util.GetLogger(ctx).WithFields(logrus.Fields{
"prev_ids_exist": exists,
"user_id": event.UserID,
"device_id": event.DeviceID,
"stream_id": event.StreamID,
"prev_ids": event.PrevID,
+ "display_name": event.DeviceDisplayName,
}).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
@@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
hasFailures = true
continue
}
- err = u.updateDeviceList(ctx, &res)
+ err = u.updateDeviceList(&res)
if err != nil {
- logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
+ logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
hasFailures = true
}
}
return hasFailures
}
-func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
+func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
+ ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
keys := make([]api.DeviceMessage, len(res.Devices))
for i, device := range res.Devices {
keyJSON, err := json.Marshal(device.Keys)
@@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs
}
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
if err != nil {
- return err
+ return fmt.Errorf("failed to store remote device keys: %w", err)
}
- return u.db.MarkDeviceListStale(ctx, res.UserID, false)
+ err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
+ if err != nil {
+ return fmt.Errorf("failed to mark device list as fresh: %w", err)
+ }
+ err = u.producer.ProduceKeyChanges(keys)
+ if err != nil {
+ return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
+ }
+ return nil
}
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
new file mode 100644
index 00000000..50e42763
--- /dev/null
+++ b/keyserver/internal/device_list_update_test.go
@@ -0,0 +1,242 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "crypto/ed25519"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var (
+ ctx = context.Background()
+)
+
+type mockKeyChangeProducer struct {
+ events []api.DeviceMessage
+}
+
+func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
+ p.events = append(p.events, keys...)
+ return nil
+}
+
+type mockDeviceListUpdaterDatabase struct {
+ staleUsers map[string]bool
+ prevIDsExist func(string, []int) bool
+ storedKeys []api.DeviceMessage
+}
+
+// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+// If no domains are given, all user IDs with stale device lists are returned.
+func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ var result []string
+ for userID := range d.staleUsers {
+ _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return nil, err
+ }
+ if len(domains) == 0 {
+ result = append(result, userID)
+ continue
+ }
+ for _, d := range domains {
+ if remoteServer == d {
+ result = append(result, userID)
+ break
+ }
+ }
+ }
+ return result, nil
+}
+
+// MarkDeviceListStale sets the stale bit for this user to isStale.
+func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
+ d.staleUsers[userID] = isStale
+ return nil
+}
+
+// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+// for this (user, device). Does not modify the stream ID for keys.
+func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ d.storedKeys = append(d.storedKeys, keys...)
+ return nil
+}
+
+// PrevIDsExists returns true if all prev IDs exist for this user.
+func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
+ return d.prevIDsExist(userID, prevIDs), nil
+}
+
+type roundTripper struct {
+ fn func(*http.Request) (*http.Response, error)
+}
+
+func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ return t.fn(req)
+}
+
+func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
+ _, pkey, _ := ed25519.GenerateKey(nil)
+ fedClient := gomatrixserverlib.NewFederationClient(
+ gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey,
+ )
+ fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper})
+ return fedClient
+}
+
+// Test that the device keys get persisted and emitted if we have the previous IDs.
+func TestUpdateHavePrevID(t *testing.T) {
+ db := &mockDeviceListUpdaterDatabase{
+ staleUsers: make(map[string]bool),
+ prevIDsExist: func(string, []int) bool {
+ return true
+ },
+ }
+ producer := &mockKeyChangeProducer{}
+ updater := NewDeviceListUpdater(db, producer, nil, 1)
+ event := gomatrixserverlib.DeviceListUpdateEvent{
+ DeviceDisplayName: "Foo Bar",
+ Deleted: false,
+ DeviceID: "FOO",
+ Keys: []byte(`{"key":"value"}`),
+ PrevID: []int{0},
+ StreamID: 1,
+ UserID: "@alice:localhost",
+ }
+ err := updater.Update(ctx, event)
+ if err != nil {
+ t.Fatalf("Update returned an error: %s", err)
+ }
+ want := api.DeviceMessage{
+ StreamID: event.StreamID,
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: event.DeviceID,
+ DisplayName: event.DeviceDisplayName,
+ KeyJSON: event.Keys,
+ UserID: event.UserID,
+ },
+ }
+ if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
+ t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
+ }
+ if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
+ t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
+ }
+ if db.staleUsers[event.UserID] {
+ t.Errorf("%s incorrectly marked as stale", event.UserID)
+ }
+}
+
+// Test that device keys are fetched from the remote server if we are missing prev IDs
+// and that the user's devices are marked as stale until it succeeds.
+func TestUpdateNoPrevID(t *testing.T) {
+ db := &mockDeviceListUpdaterDatabase{
+ staleUsers: make(map[string]bool),
+ prevIDsExist: func(string, []int) bool {
+ return false
+ },
+ }
+ producer := &mockKeyChangeProducer{}
+ remoteUserID := "@alice:example.somewhere"
+ var wg sync.WaitGroup
+ wg.Add(1)
+ keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
+ fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
+ defer wg.Done()
+ if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
+ return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
+ }
+ return &http.Response{
+ StatusCode: 200,
+ Body: ioutil.NopCloser(strings.NewReader(`
+ {
+ "user_id": "` + remoteUserID + `",
+ "stream_id": 5,
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": ` + keyJSON + `,
+ "device_display_name": "Mobile Phone"
+ }
+ ]
+ }
+ `)),
+ }, nil
+ })
+ updater := NewDeviceListUpdater(db, producer, fedClient, 2)
+ if err := updater.Start(); err != nil {
+ t.Fatalf("failed to start updater: %s", err)
+ }
+ event := gomatrixserverlib.DeviceListUpdateEvent{
+ DeviceDisplayName: "Mobile Phone",
+ Deleted: false,
+ DeviceID: "another_device_id",
+ Keys: []byte(`{"key":"value"}`),
+ PrevID: []int{3},
+ StreamID: 4,
+ UserID: remoteUserID,
+ }
+ err := updater.Update(ctx, event)
+ if err != nil {
+ t.Fatalf("Update returned an error: %s", err)
+ }
+ // At this point we show have this device list marked as stale and not store the keys or emitted anything
+ if !db.staleUsers[event.UserID] {
+ t.Errorf("%s not marked as stale", event.UserID)
+ }
+ if len(producer.events) > 0 {
+ t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
+ }
+ if len(db.storedKeys) > 0 {
+ t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
+ }
+ t.Log("waiting for /users/devices to be called...")
+ wg.Wait()
+ // wait a bit for db to be updated...
+ time.Sleep(100 * time.Millisecond)
+ want := api.DeviceMessage{
+ StreamID: 5,
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: "JLAFKJWSCS",
+ DisplayName: "Mobile Phone",
+ UserID: remoteUserID,
+ KeyJSON: []byte(keyJSON),
+ },
+ }
+ // Now we should have a fresh list and the keys and emitted something
+ if db.staleUsers[event.UserID] {
+ t.Errorf("%s still marked as stale", event.UserID)
+ }
+ if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
+ t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
+ t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
+ }
+ if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
+ t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
+ }
+
+}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index ff298c07..075622b7 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
if len(dk.KeyJSON) == 0 {
continue // don't include blank keys
}
- // inject display name if known
+ // inject display name if known (either locally or remotely)
+ displayName := dk.DisplayName
+ if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
+ displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
+ }
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
DisplayName string `json:"device_display_name,omitempty"`
- }{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
+ }{displayName})
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
}
} else {
@@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
}
}
- // TODO: set device display names when they are known
+
+ // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
+ domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
+ if len(domainToDeviceKeys) == 0 {
+ return // nothing to query
+ }
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
}
+func (a *KeyInternalAPI) remoteKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
+) map[string]map[string][]string {
+ fetchRemote := make(map[string]map[string][]string)
+ for domain, userToDeviceMap := range domainToDeviceKeys {
+ for userID, deviceIDs := range userToDeviceMap {
+ keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ // if we can't query the db or there are fewer keys than requested, fetch from remote.
+ // Likewise, we can't safely return keys from the db when all devices are requested as we don't
+ // know if one has just been added.
+ if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
+ if _, ok := fetchRemote[domain]; !ok {
+ fetchRemote[domain] = make(map[string][]string)
+ }
+ fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
+ continue
+ }
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+ for _, key := range keys {
+ // inject the display name
+ key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{key.DisplayName})
+ res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ }
+ }
+ }
+ return fetchRemote
+}
+
func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
) {
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index d321860d..b9d5d4c3 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
+ display_name TEXT,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
- " VALUES ($1, $2, $3, $4, $5)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
- " DO UPDATE SET key_json = $4, stream_id = $5"
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
- "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys {
var keyJSONStr string
var streamID int
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
}
return nil
}
@@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err
@@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID
var keyJSON string
var streamID int
- if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
+ var displayName sql.NullString
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go
new file mode 100644
index 00000000..63281adf
--- /dev/null
+++ b/keyserver/storage/postgres/stale_device_lists.go
@@ -0,0 +1,118 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
+
+type staleDeviceListsStatements struct {
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+}
+
+func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{}
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
+ return nil, err
+ }
+ if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
+ return nil, err
+ }
+ if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go
index a1d1c0fe..de2fabfd 100644
--- a/keyserver/storage/postgres/storage.go
+++ b/keyserver/storage/postgres/storage.go
@@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
if err != nil {
return nil, err
}
+ sdl, err := NewPostgresStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
return &shared.Database{
- DB: db,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
+ DB: db,
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
}, nil
}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 68964be6..4279eae7 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -26,10 +26,11 @@ import (
)
type Database struct {
- DB *sql.DB
- OneTimeKeysTable tables.OneTimeKeys
- DeviceKeysTable tables.DeviceKeys
- KeyChangesTable tables.KeyChanges
+ DB *sql.DB
+ OneTimeKeysTable tables.OneTimeKeys
+ DeviceKeysTable tables.DeviceKeys
+ KeyChangesTable tables.KeyChanges
+ StaleDeviceListsTable tables.StaleDeviceLists
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset,
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
- return nil, nil // TODO
+ return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
- return nil // TODO
+ return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index 15d9c775..abe6636a 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
+ display_name TEXT,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
- " VALUES ($1, $2, $3, $4, $5)" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id)" +
- " DO UPDATE SET key_json = $4, stream_id = $5"
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
- "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
@@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.UserID = userID
var keyJSON string
var streamID int
- if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
+ var displayName sql.NullString
+ if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
@@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
for i, key := range keys {
var keyJSONStr string
var streamID int
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
}
return nil
}
@@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index f910479f..907966a7 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return err
})
+ if keyJSON == "" {
+ return nil, nil
+ }
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go
new file mode 100644
index 00000000..a989476d
--- /dev/null
+++ b/keyserver/storage/sqlite3/stale_device_lists.go
@@ -0,0 +1,118 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/keyserver/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
+
+type staleDeviceListsStatements struct {
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+}
+
+func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{}
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
+ return nil, err
+ }
+ if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
+ return nil, err
+ }
+ if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
index f9771cf1..bbfd1e79 100644
--- a/keyserver/storage/sqlite3/storage.go
+++ b/keyserver/storage/sqlite3/storage.go
@@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
if err != nil {
return nil, err
}
+ sdl, err := NewSqliteStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
return &shared.Database{
- DB: db,
- OneTimeKeysTable: otk,
- DeviceKeysTable: dk,
- KeyChangesTable: kc,
+ DB: db,
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
}, nil
}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index ac932d56..a4d5dede 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -20,6 +20,7 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
)
type OneTimeKeys interface {
@@ -45,3 +46,8 @@ type KeyChanges interface {
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
}
+
+type StaleDeviceLists interface {
+ InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
+ SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+}