aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage/postgres/device_keys_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/storage/postgres/device_keys_table.go')
-rw-r--r--keyserver/storage/postgres/device_keys_table.go228
1 files changed, 0 insertions, 228 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
deleted file mode 100644
index 2aa11c52..00000000
--- a/keyserver/storage/postgres/device_keys_table.go
+++ /dev/null
@@ -1,228 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "database/sql"
- "time"
-
- "github.com/lib/pq"
-
- "github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage/tables"
-)
-
-var deviceKeysSchema = `
--- Stores device keys for users
-CREATE TABLE IF NOT EXISTS keyserver_device_keys (
- user_id TEXT NOT NULL,
- device_id TEXT NOT NULL,
- ts_added_secs BIGINT NOT NULL,
- key_json TEXT NOT NULL,
- -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
- -- This means we do not store an unbounded append-only log of device keys, which is not actually
- -- required in the spec because in the event of a missed update the server fetches the entire
- -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
- stream_id BIGINT NOT NULL,
- display_name TEXT,
- -- Clobber based on tuple of user/device.
- CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
-);
-`
-
-const upsertDeviceKeysSQL = "" +
- "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
- " VALUES ($1, $2, $3, $4, $5, $6)" +
- " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
- " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
-
-const selectDeviceKeysSQL = "" +
- "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const selectBatchDeviceKeysSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
-
-const selectBatchDeviceKeysWithEmptiesSQL = "" +
- "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
-
-const selectMaxStreamForUserSQL = "" +
- "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
-
-const countStreamIDsForUserSQL = "" +
- "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
-
-const deleteDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
-
-const deleteAllDeviceKeysSQL = "" +
- "DELETE FROM keyserver_device_keys WHERE user_id=$1"
-
-type deviceKeysStatements struct {
- db *sql.DB
- upsertDeviceKeysStmt *sql.Stmt
- selectDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysStmt *sql.Stmt
- selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
- selectMaxStreamForUserStmt *sql.Stmt
- countStreamIDsForUserStmt *sql.Stmt
- deleteDeviceKeysStmt *sql.Stmt
- deleteAllDeviceKeysStmt *sql.Stmt
-}
-
-func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
- s := &deviceKeysStatements{
- db: db,
- }
- _, err := db.Exec(deviceKeysSchema)
- if err != nil {
- return nil, err
- }
- if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
- return nil, err
- }
- if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
- return nil, err
- }
- if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
- return nil, err
- }
- if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
- return nil, err
- }
- if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
- return nil, err
- }
- return s, nil
-}
-
-func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
- for i, key := range keys {
- var keyJSONStr string
- var streamID int64
- var displayName sql.NullString
- err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
- if err != nil && err != sql.ErrNoRows {
- return err
- }
- // this will be '' when there is no device
- keys[i].Type = api.TypeDeviceKeyUpdate
- keys[i].KeyJSON = []byte(keyJSONStr)
- keys[i].StreamID = streamID
- if displayName.Valid {
- keys[i].DisplayName = displayName.String
- }
- }
- return nil
-}
-
-func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
- // nullable if there are no results
- var nullStream sql.NullInt64
- err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
- if err == sql.ErrNoRows {
- err = nil
- }
- if nullStream.Valid {
- streamID = nullStream.Int64
- }
- return
-}
-
-func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
- // nullable if there are no results
- var count sql.NullInt32
- err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
- if err != nil {
- return 0, err
- }
- if count.Valid {
- return int(count.Int32), nil
- }
- return 0, nil
-}
-
-func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
- for _, key := range keys {
- now := time.Now().Unix()
- _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
- ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
- )
- if err != nil {
- return err
- }
- }
- return nil
-}
-
-func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
- return err
-}
-
-func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
- _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
- return err
-}
-
-func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
- var stmt *sql.Stmt
- if includeEmpty {
- stmt = s.selectBatchDeviceKeysWithEmptiesStmt
- } else {
- stmt = s.selectBatchDeviceKeysStmt
- }
- rows, err := stmt.QueryContext(ctx, userID)
- if err != nil {
- return nil, err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
- deviceIDMap := make(map[string]bool)
- for _, d := range deviceIDs {
- deviceIDMap[d] = true
- }
- var result []api.DeviceMessage
- var displayName sql.NullString
- for rows.Next() {
- dk := api.DeviceMessage{
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- UserID: userID,
- },
- }
- if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
- return nil, err
- }
- if displayName.Valid {
- dk.DisplayName = displayName.String
- }
- // include the key if we want all keys (no device) or it was asked
- if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
- result = append(result, dk)
- }
- }
- return result, rows.Err()
-}