aboutsummaryrefslogtreecommitdiff
path: root/internal/keydb/sqlite3/server_key_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/keydb/sqlite3/server_key_table.go')
-rw-r--r--internal/keydb/sqlite3/server_key_table.go152
1 files changed, 152 insertions, 0 deletions
diff --git a/internal/keydb/sqlite3/server_key_table.go b/internal/keydb/sqlite3/server_key_table.go
new file mode 100644
index 00000000..ae24a14d
--- /dev/null
+++ b/internal/keydb/sqlite3/server_key_table.go
@@ -0,0 +1,152 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const serverKeysSchema = `
+-- A cache of signing keys downloaded from remote servers.
+CREATE TABLE IF NOT EXISTS keydb_server_keys (
+ -- The name of the matrix server the key is for.
+ server_name TEXT NOT NULL,
+ -- The ID of the server key.
+ server_key_id TEXT NOT NULL,
+ -- Combined server name and key ID separated by the ASCII unit separator
+ -- to make it easier to run bulk queries.
+ server_name_and_key_id TEXT NOT NULL,
+ -- When the key is valid until as a millisecond timestamp.
+ -- 0 if this is an expired key (in which case expired_ts will be non-zero)
+ valid_until_ts BIGINT NOT NULL,
+ -- When the key expired as a millisecond timestamp.
+ -- 0 if this is an active key (in which case valid_until_ts will be non-zero)
+ expired_ts BIGINT NOT NULL,
+ -- The base64-encoded public key.
+ server_key TEXT NOT NULL,
+ UNIQUE (server_name, server_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
+`
+
+const bulkSelectServerKeysSQL = "" +
+ "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
+ " server_key FROM keydb_server_keys" +
+ " WHERE server_name_and_key_id IN ($1)"
+
+const upsertServerKeysSQL = "" +
+ "INSERT INTO keydb_server_keys (server_name, server_key_id," +
+ " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (server_name, server_key_id)" +
+ " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
+
+type serverKeyStatements struct {
+ db *sql.DB
+ bulkSelectServerKeysStmt *sql.Stmt
+ upsertServerKeysStmt *sql.Stmt
+}
+
+func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(serverKeysSchema)
+ if err != nil {
+ return
+ }
+ if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
+ return
+ }
+ if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *serverKeyStatements) bulkSelectServerKeys(
+ ctx context.Context,
+ requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
+) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
+ var nameAndKeyIDs []string
+ for request := range requests {
+ nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
+ }
+
+ query := strings.Replace(bulkSelectServerKeysSQL, "($1)", internal.QueryVariadic(len(nameAndKeyIDs)), 1)
+
+ iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
+ for i, v := range nameAndKeyIDs {
+ iKeyIDs[i] = v
+ }
+
+ rows, err := s.db.QueryContext(ctx, query, iKeyIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
+ results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
+ for rows.Next() {
+ var serverName string
+ var keyID string
+ var key string
+ var validUntilTS int64
+ var expiredTS int64
+ if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
+ return nil, err
+ }
+ r := gomatrixserverlib.PublicKeyLookupRequest{
+ ServerName: gomatrixserverlib.ServerName(serverName),
+ KeyID: gomatrixserverlib.KeyID(keyID),
+ }
+ vk := gomatrixserverlib.VerifyKey{}
+ err = vk.Key.Decode(key)
+ if err != nil {
+ return nil, err
+ }
+ results[r] = gomatrixserverlib.PublicKeyLookupResult{
+ VerifyKey: vk,
+ ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
+ ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
+ }
+ }
+ return results, nil
+}
+
+func (s *serverKeyStatements) upsertServerKeys(
+ ctx context.Context,
+ request gomatrixserverlib.PublicKeyLookupRequest,
+ key gomatrixserverlib.PublicKeyLookupResult,
+) error {
+ _, err := s.upsertServerKeysStmt.ExecContext(
+ ctx,
+ string(request.ServerName),
+ string(request.KeyID),
+ nameAndKeyID(request),
+ key.ValidUntilTS,
+ key.ExpiredTS,
+ key.Key.Encode(),
+ )
+ return err
+}
+
+func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
+ return string(request.ServerName) + "\x1F" + string(request.KeyID)
+}