diff options
author | Kegsay <kegan@matrix.org> | 2020-07-15 12:02:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-15 12:02:34 +0100 |
commit | 9dd2ed7f6513e8fa677dee8d7dafa33f9c7afdfc (patch) | |
tree | 5c09582128d156aa2b6629cdaae626b44357b48d /keyserver/storage/postgres | |
parent | b4c07995d68dbeffa2161920cb4cd61ea2be8389 (diff) |
Implement key uploads (#1202)
* Add storage layer for postgres/sqlite
* Return OTK counts when inserting new keys
* Hook up the key DB and make a test pass
* Convert postgres queries to be sqlite queries
* Blacklist test due to requiring rejected events
* Unbreak tests
* Update blacklist
Diffstat (limited to 'keyserver/storage/postgres')
-rw-r--r-- | keyserver/storage/postgres/device_keys_table.go | 97 | ||||
-rw-r--r-- | keyserver/storage/postgres/one_time_keys_table.go | 143 | ||||
-rw-r--r-- | keyserver/storage/postgres/storage.go | 42 |
3 files changed, 282 insertions, 0 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go new file mode 100644 index 00000000..b05ec093 --- /dev/null +++ b/keyserver/storage/postgres/device_keys_table.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on tuple of user/device. + CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + + " DO UPDATE SET key_json = $4" + +const selectDeviceKeysSQL = "" + + "SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt +} + +func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil { + return nil, err + } + if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { + for i, key := range keys { + var keyJSONStr string + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].KeyJSON = []byte(keyJSONStr) + } + return nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { + now := time.Now().Unix() + return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + for _, key := range keys { + _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), + ) + if err != nil { + return err + } + } + return nil + }) +} diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go new file mode 100644 index 00000000..b8aee72b --- /dev/null +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -0,0 +1,143 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) +); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + + " DO UPDATE SET key_json = $6" + +const selectKeysSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt +} + +func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil { + return nil, err + } + if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { + return nil, err + } + if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + return counts, sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := txn.Stmt(s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return err + } + } + rows, err := txn.Stmt(s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return err + } + counts.KeyCount[algorithm] = count + } + + return rows.Err() + }) +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go new file mode 100644 index 00000000..4f3217b6 --- /dev/null +++ b/keyserver/storage/postgres/storage.go @@ -0,0 +1,42 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/shared" +) + +// NewDatabase creates a new sync server database +func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*shared.Database, error) { + var err error + db, err := sqlutil.Open("postgres", dbDataSourceName, dbProperties) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + return &shared.Database{ + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + }, nil +} |