aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-07-28 17:38:30 +0100
committerGitHub <noreply@github.com>2020-07-28 17:38:30 +0100
commitadf7b5929401f56bedba92ef778b5e56feefc479 (patch)
treeadb05d580ddffa9aaaf6a6634eed480aee00f7c9
parentacc8e80a51515c953c6710cb24f36fd9d1f7aeb1 (diff)
Persist partition|offset|user_id in the keyserver (#1226)
* Persist partition|offset|user_id in the keyserver Required for a query API which will be used by the syncapi which will be called when a `/sync` request comes in which will return a list of user IDs of people who have changed their device keys between two tokens. * Add tests and fix maxOffset bug * s/offset/log_offset/g because 'offset' is a reserved word in postgres
-rw-r--r--keyserver/keyserver.go1
-rw-r--r--keyserver/producers/keychange.go7
-rw-r--r--keyserver/storage/interface.go8
-rw-r--r--keyserver/storage/postgres/key_changes_table.go97
-rw-r--r--keyserver/storage/postgres/storage.go5
-rw-r--r--keyserver/storage/shared/storage.go9
-rw-r--r--keyserver/storage/sqlite3/key_changes_table.go98
-rw-r--r--keyserver/storage/sqlite3/storage.go5
-rw-r--r--keyserver/storage/storage_test.go57
-rw-r--r--keyserver/storage/tables/interface.go5
10 files changed, 292 insertions, 0 deletions
diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go
index 47c6a8c3..c748d7ce 100644
--- a/keyserver/keyserver.go
+++ b/keyserver/keyserver.go
@@ -49,6 +49,7 @@ func NewInternalAPI(
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
Producer: producer,
+ DB: db,
}
return &internal.KeyInternalAPI{
DB: db,
diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go
index 6683a936..d59dd200 100644
--- a/keyserver/producers/keychange.go
+++ b/keyserver/producers/keychange.go
@@ -15,10 +15,12 @@
package producers
import (
+ "context"
"encoding/json"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/sirupsen/logrus"
)
@@ -26,6 +28,7 @@ import (
type KeyChange struct {
Topic string
Producer sarama.SyncProducer
+ DB storage.Database
}
// ProduceKeyChanges creates new change events for each key
@@ -46,6 +49,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
if err != nil {
return err
}
+ err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
+ if err != nil {
+ return err
+ }
logrus.WithFields(logrus.Fields{
"user_id": key.UserID,
"device_id": key.DeviceID,
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 7a0328bd..f4787790 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -43,4 +43,12 @@ type Database interface {
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
+
+ // StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
+ // their keys in some way.
+ StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
+
+ // KeyChanges returns a list of user IDs who have modified their keys from the offset given.
+ // Returns the offset of the latest key change.
+ KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
}
diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go
new file mode 100644
index 00000000..9d259f9f
--- /dev/null
+++ b/keyserver/storage/postgres/key_changes_table.go
@@ -0,0 +1,97 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/keyserver/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ log_offset BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
+);
+`
+
+// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
+// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
+// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
+const upsertKeyChangeSQL = "" +
+ "INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
+ " DO UPDATE SET user_id = $3"
+
+// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
+// take the max offset value as the latest offset.
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
+ return nil, err
+ }
+ if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
+ _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
+ return err
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, partition int32, fromOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go
index 4f3217b6..a1d1c0fe 100644
--- a/keyserver/storage/postgres/storage.go
+++ b/keyserver/storage/postgres/storage.go
@@ -34,9 +34,14 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
if err != nil {
return nil, err
}
+ kc, err := NewPostgresKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
+ KeyChangesTable: kc,
}, nil
}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 156b5b41..537a5f7b 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -28,6 +28,7 @@ type Database struct {
DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
+ KeyChangesTable tables.KeyChanges
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
})
return result, err
}
+
+func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
+ return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
+}
+
+func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
+ return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
+}
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
new file mode 100644
index 00000000..b830214d
--- /dev/null
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -0,0 +1,98 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/keyserver/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ offset BIGINT NOT NULL,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (partition, offset)
+);
+`
+
+// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
+// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
+// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
+const upsertKeyChangeSQL = "" +
+ "INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT (partition, offset)" +
+ " DO UPDATE SET user_id = $3"
+
+// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
+// take the max offset value as the latest offset.
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return nil, err
+ }
+ if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
+ return nil, err
+ }
+ if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
+ return nil, err
+ }
+ return s, nil
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
+ _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
+ return err
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, partition int32, fromOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
index f3566ef5..f9771cf1 100644
--- a/keyserver/storage/sqlite3/storage.go
+++ b/keyserver/storage/sqlite3/storage.go
@@ -37,9 +37,14 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
if err != nil {
return nil, err
}
+ kc, err := NewSqliteKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
return &shared.Database{
DB: db,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
+ KeyChangesTable: kc,
}, nil
}
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
new file mode 100644
index 00000000..88972478
--- /dev/null
+++ b/keyserver/storage/storage_test.go
@@ -0,0 +1,57 @@
+package storage
+
+import (
+ "context"
+ "reflect"
+ "testing"
+)
+
+var ctx = context.Background()
+
+func MustNotError(t *testing.T, err error) {
+ t.Helper()
+ if err == nil {
+ return
+ }
+ t.Fatalf("operation failed: %s", err)
+}
+
+func TestKeyChanges(t *testing.T) {
+ db, err := NewDatabase("file::memory:", nil)
+ if err != nil {
+ t.Fatalf("Failed to NewDatabase: %s", err)
+ }
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
+ userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != 2 {
+ t.Fatalf("KeyChanges: got latest=%d want 2", latest)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+}
+
+func TestKeyChangesNoDupes(t *testing.T) {
+ db, err := NewDatabase("file::memory:", nil)
+ if err != nil {
+ t.Fatalf("Failed to NewDatabase: %s", err)
+ }
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
+ MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
+ userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != 2 {
+ t.Fatalf("KeyChanges: got latest=%d want 2", latest)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index 216be773..824b9f0f 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -35,3 +35,8 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
}
+
+type KeyChanges interface {
+ InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
+ SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
+}