aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/basecomponent/base.go7
-rw-r--r--common/caching/immutablecache.go3
-rw-r--r--common/caching/immutableinmemorylru.go28
-rw-r--r--common/keydb/cache/keydb.go69
-rw-r--r--common/keydb/postgres/keydb.go2
-rw-r--r--common/keydb/sqlite3/keydb.go2
-rw-r--r--common/keydb/sqlite3/server_key_table.go24
7 files changed, 108 insertions, 27 deletions
diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go
index cb04a308..4342e25a 100644
--- a/common/basecomponent/base.go
+++ b/common/basecomponent/base.go
@@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/keydb"
+ "github.com/matrix-org/dendrite/common/keydb/cache"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/naffka"
@@ -186,7 +187,11 @@ func (b *BaseDendrite) CreateKeyDB() keydb.Database {
logrus.WithError(err).Panicf("failed to connect to keys db")
}
- return db
+ cachedDB, err := cache.NewKeyDatabase(db, b.ImmutableCache)
+ if err != nil {
+ logrus.WithError(err).Panicf("failed to create key cache wrapper")
+ }
+ return cachedDB
}
// CreateFederationClient creates a new federation client. Should only be called
diff --git a/common/caching/immutablecache.go b/common/caching/immutablecache.go
index 9620667a..362e4349 100644
--- a/common/caching/immutablecache.go
+++ b/common/caching/immutablecache.go
@@ -4,9 +4,12 @@ import "github.com/matrix-org/gomatrixserverlib"
const (
RoomVersionMaxCacheEntries = 128
+ ServerKeysMaxCacheEntries = 128
)
type ImmutableCache interface {
GetRoomVersion(roomId string) (gomatrixserverlib.RoomVersion, bool)
StoreRoomVersion(roomId string, roomVersion gomatrixserverlib.RoomVersion)
+ GetServerKey(request gomatrixserverlib.PublicKeyLookupRequest) (gomatrixserverlib.PublicKeyLookupResult, bool)
+ StoreServerKey(request gomatrixserverlib.PublicKeyLookupRequest, response gomatrixserverlib.PublicKeyLookupResult)
}
diff --git a/common/caching/immutableinmemorylru.go b/common/caching/immutableinmemorylru.go
index 3e8f4aad..6d2a785f 100644
--- a/common/caching/immutableinmemorylru.go
+++ b/common/caching/immutableinmemorylru.go
@@ -9,6 +9,7 @@ import (
type ImmutableInMemoryLRUCache struct {
roomVersions *lru.Cache
+ serverKeys *lru.Cache
}
func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
@@ -16,8 +17,13 @@ func NewImmutableInMemoryLRUCache() (*ImmutableInMemoryLRUCache, error) {
if rvErr != nil {
return nil, rvErr
}
+ serverKeysCache, rvErr := lru.New(ServerKeysMaxCacheEntries)
+ if rvErr != nil {
+ return nil, rvErr
+ }
return &ImmutableInMemoryLRUCache{
roomVersions: roomVersionCache,
+ serverKeys: serverKeysCache,
}, nil
}
@@ -41,3 +47,25 @@ func (c *ImmutableInMemoryLRUCache) StoreRoomVersion(roomID string, roomVersion
checkForInvalidMutation(c.roomVersions, roomID, roomVersion)
c.roomVersions.Add(roomID, roomVersion)
}
+
+func (c *ImmutableInMemoryLRUCache) GetServerKey(
+ request gomatrixserverlib.PublicKeyLookupRequest,
+) (gomatrixserverlib.PublicKeyLookupResult, bool) {
+ key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
+ val, found := c.serverKeys.Get(key)
+ if found && val != nil {
+ if keyLookupResult, ok := val.(gomatrixserverlib.PublicKeyLookupResult); ok {
+ return keyLookupResult, true
+ }
+ }
+ return gomatrixserverlib.PublicKeyLookupResult{}, false
+}
+
+func (c *ImmutableInMemoryLRUCache) StoreServerKey(
+ request gomatrixserverlib.PublicKeyLookupRequest,
+ response gomatrixserverlib.PublicKeyLookupResult,
+) {
+ key := fmt.Sprintf("%s/%s", request.ServerName, request.KeyID)
+ checkForInvalidMutation(c.roomVersions, key, response)
+ c.serverKeys.Add(request, response)
+}
diff --git a/common/keydb/cache/keydb.go b/common/keydb/cache/keydb.go
new file mode 100644
index 00000000..ae929fa4
--- /dev/null
+++ b/common/keydb/cache/keydb.go
@@ -0,0 +1,69 @@
+package cache
+
+import (
+ "context"
+ "errors"
+
+ "github.com/matrix-org/dendrite/common/caching"
+ "github.com/matrix-org/dendrite/common/keydb"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A Database implements gomatrixserverlib.KeyDatabase and is used to store
+// the public keys for other matrix servers.
+type KeyDatabase struct {
+ inner keydb.Database
+ cache caching.ImmutableCache
+}
+
+func NewKeyDatabase(inner keydb.Database, cache caching.ImmutableCache) (*KeyDatabase, error) {
+ if inner == nil {
+ return nil, errors.New("inner database can't be nil")
+ }
+ if cache == nil {
+ return nil, errors.New("cache can't be nil")
+ }
+ return &KeyDatabase{
+ inner: inner,
+ cache: cache,
+ }, nil
+}
+
+// FetcherName implements KeyFetcher
+func (d KeyDatabase) FetcherName() string {
+ return "InMemoryKeyCache"
+}
+
+// FetchKeys implements gomatrixserverlib.KeyDatabase
+func (d *KeyDatabase) FetchKeys(
+ ctx context.Context,
+ requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
+) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
+ results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
+ for req := range requests {
+ if res, cached := d.cache.GetServerKey(req); cached {
+ results[req] = res
+ delete(requests, req)
+ }
+ }
+ fromDB, err := d.inner.FetchKeys(ctx, requests)
+ if err != nil {
+ return results, err
+ }
+ for req, res := range fromDB {
+ results[req] = res
+ d.cache.StoreServerKey(req, res)
+ }
+ return results, nil
+}
+
+// StoreKeys implements gomatrixserverlib.KeyDatabase
+func (d *KeyDatabase) StoreKeys(
+ ctx context.Context,
+ keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
+) error {
+ for req, res := range keyMap {
+ d.cache.StoreServerKey(req, res)
+ }
+ return d.inner.StoreKeys(ctx, keyMap)
+}
diff --git a/common/keydb/postgres/keydb.go b/common/keydb/postgres/keydb.go
index 6149d877..706ca005 100644
--- a/common/keydb/postgres/keydb.go
+++ b/common/keydb/postgres/keydb.go
@@ -79,7 +79,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
- return "KeyDatabase"
+ return "PostgresKeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
diff --git a/common/keydb/sqlite3/keydb.go b/common/keydb/sqlite3/keydb.go
index 1405836a..94a32e29 100644
--- a/common/keydb/sqlite3/keydb.go
+++ b/common/keydb/sqlite3/keydb.go
@@ -80,7 +80,7 @@ func NewDatabase(
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
- return "KeyDatabase"
+ return "SqliteKeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
diff --git a/common/keydb/sqlite3/server_key_table.go b/common/keydb/sqlite3/server_key_table.go
index ba1cc060..883d3cd0 100644
--- a/common/keydb/sqlite3/server_key_table.go
+++ b/common/keydb/sqlite3/server_key_table.go
@@ -20,10 +20,8 @@ import (
"database/sql"
"strings"
- lru "github.com/hashicorp/golang-lru"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/util"
)
const serverKeysSchema = `
@@ -66,16 +64,10 @@ type serverKeyStatements struct {
db *sql.DB
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
-
- cache *lru.Cache // nameAndKeyID => gomatrixserverlib.PublicKeyLookupResult
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db
- s.cache, err = lru.New(64)
- if err != nil {
- return
- }
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
@@ -98,21 +90,6 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
- // If we can satisfy all of the requests from the cache, do so. TODO: Allow partial matches with merges.
- cacheResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
- for request := range requests {
- r, ok := s.cache.Get(nameAndKeyID(request))
- if !ok {
- break
- }
- cacheResult := r.(gomatrixserverlib.PublicKeyLookupResult)
- cacheResults[request] = cacheResult
- }
- if len(cacheResults) == len(requests) {
- util.GetLogger(ctx).Infof("KeyDB cache hit for %d keys", len(cacheResults))
- return cacheResults, nil
- }
-
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", common.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
@@ -158,7 +135,6 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
- s.cache.Add(nameAndKeyID(request), key)
_, err := s.upsertServerKeysStmt.ExecContext(
ctx,
string(request.ServerName),