diff options
Diffstat (limited to 'serverkeyapi/storage/sqlite3/server_key_table.go')
-rw-r--r-- | serverkeyapi/storage/sqlite3/server_key_table.go | 67 |
1 files changed, 34 insertions, 33 deletions
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index f756ef5e..2484d636 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -18,9 +18,8 @@ package sqlite3 import ( "context" "database/sql" - "strings" + "fmt" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -88,48 +87,50 @@ func (s *serverKeyStatements) bulkSelectServerKeys( ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - var nameAndKeyIDs []string + nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) - + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { iKeyIDs[i] = v } - rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) + err := sqlutil.RunLimitedVariablesQuery( + ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, + func(rows *sql.Rows) error { + for rows.Next() { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + return nil + }, + ) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") - results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for rows.Next() { - var serverName string - var keyID string - var key string - var validUntilTS int64 - var expiredTS int64 - if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { - return nil, err - } - r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), - KeyID: gomatrixserverlib.KeyID(keyID), - } - vk := gomatrixserverlib.VerifyKey{} - err = vk.Key.Decode(key) - if err != nil { - return nil, err - } - results[r] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), - } - } return results, nil } |