aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-14 16:39:38 +0100
committerGitHub <noreply@github.com>2020-09-14 16:39:38 +0100
commit8dc95062101b3906ffb83604e2abca02d9a3dd03 (patch)
treebdd5ec08d14c113346dd7d731e51403e83e5940b
parent913020e4b7ceeaf67b4c7ed8e0778126989846ef (diff)
Don't use more than 999 variables in SQLite querys. (#1425)
* Don't use more than 999 variables in SQLite querys. Solve this problem in a more general and reusable way. Also fix #1369 Add some unit tests. Signed-off-by: Henrik Sölver <henrik.solver@gmail.com> * Don't rely on testify for basic assertions * Readability improvements and linting Co-authored-by: Henrik Sölver <henrik.solver@gmail.com>
-rw-r--r--go.mod1
-rw-r--r--go.sum2
-rw-r--r--internal/sqlutil/sql.go45
-rw-r--r--internal/sqlutil/sqlutil_test.go173
-rw-r--r--serverkeyapi/storage/sqlite3/server_key_table.go67
5 files changed, 255 insertions, 33 deletions
diff --git a/go.mod b/go.mod
index f1cb3c9b..6b1c03b5 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,7 @@
module github.com/matrix-org/dendrite
require (
+ github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/gologme/log v1.2.0
diff --git a/go.sum b/go.sum
index ac7827d9..5c4f27a5 100644
--- a/go.sum
+++ b/go.sum
@@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
+github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index 1d2825d5..90562ded 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -15,10 +15,14 @@
package sqlutil
import (
+ "context"
"database/sql"
"errors"
"fmt"
"runtime"
+ "strings"
+
+ "github.com/matrix-org/util"
)
// ErrUserExists is returned if a username already exists in the database.
@@ -107,3 +111,44 @@ func SQLiteDriverName() string {
}
return "sqlite3"
}
+
+func minOfInts(a, b int) int {
+ if a <= b {
+ return a
+ }
+ return b
+}
+
+// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
+type QueryProvider interface {
+ QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
+}
+
+// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
+// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
+const SQLite3MaxVariables = 999
+
+// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
+func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error {
+ var start int
+ for start < len(variables) {
+ n := minOfInts(len(variables)-start, int(limit))
+ nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1)
+ rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error")
+ return err
+ }
+ err = rowHandler(rows)
+ if closeErr := rows.Close(); closeErr != nil {
+ util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows")
+ return err
+ }
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error")
+ return err
+ }
+ start = start + n
+ }
+ return nil
+}
diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go
new file mode 100644
index 00000000..79469cdd
--- /dev/null
+++ b/internal/sqlutil/sqlutil_test.go
@@ -0,0 +1,173 @@
+package sqlutil
+
+import (
+ "context"
+ "database/sql"
+ "reflect"
+ "testing"
+
+ sqlmock "github.com/DATA-DOG/go-sqlmock"
+)
+
+func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ assertNoError(t, err, "Failed to make DB")
+ limit := uint(4)
+
+ r := mock.NewRows([]string{"id"}).
+ AddRow(1).
+ AddRow(2).
+ AddRow(3)
+
+ mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
+ // nolint:goconst
+ q := "SELECT id WHERE id IN ($1)"
+ v := []int{1, 2, 3}
+ iKeyIDs := make([]interface{}, len(v))
+ for i, d := range v {
+ iKeyIDs[i] = d
+ }
+
+ ctx := context.Background()
+ var result = make([]int, 0)
+ err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
+ for rows.Next() {
+ var id int
+ err = rows.Scan(&id)
+ assertNoError(t, err, "rows.Scan returned an error")
+ result = append(result, id)
+ }
+ return nil
+ })
+ assertNoError(t, err, "Call returned an error")
+ if len(result) != len(v) {
+ t.Fatalf("Result should be 3 long")
+ }
+}
+
+func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ assertNoError(t, err, "Failed to make DB")
+ limit := uint(4)
+
+ r := mock.NewRows([]string{"id"}).
+ AddRow(1).
+ AddRow(2).
+ AddRow(3).
+ AddRow(4)
+
+ mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r)
+ // nolint:goconst
+ q := "SELECT id WHERE id IN ($1)"
+ v := []int{1, 2, 3, 4}
+ iKeyIDs := make([]interface{}, len(v))
+ for i, d := range v {
+ iKeyIDs[i] = d
+ }
+
+ ctx := context.Background()
+ var result = make([]int, 0)
+ err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
+ for rows.Next() {
+ var id int
+ err = rows.Scan(&id)
+ assertNoError(t, err, "rows.Scan returned an error")
+ result = append(result, id)
+ }
+ return nil
+ })
+ assertNoError(t, err, "Call returned an error")
+ if len(result) != len(v) {
+ t.Fatalf("Result should be 4 long")
+ }
+}
+
+func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ assertNoError(t, err, "Failed to make DB")
+ limit := uint(4)
+
+ r1 := mock.NewRows([]string{"id"}).
+ AddRow(1).
+ AddRow(2).
+ AddRow(3).
+ AddRow(4)
+
+ r2 := mock.NewRows([]string{"id"}).
+ AddRow(5)
+
+ mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1)
+ mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2)
+ // nolint:goconst
+ q := "SELECT id WHERE id IN ($1)"
+ v := []int{1, 2, 3, 4, 5}
+ iKeyIDs := make([]interface{}, len(v))
+ for i, d := range v {
+ iKeyIDs[i] = d
+ }
+
+ ctx := context.Background()
+ var result = make([]int, 0)
+ err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
+ for rows.Next() {
+ var id int
+ err = rows.Scan(&id)
+ assertNoError(t, err, "rows.Scan returned an error")
+ result = append(result, id)
+ }
+ return nil
+ })
+ assertNoError(t, err, "Call returned an error")
+ if len(result) != len(v) {
+ t.Fatalf("Result should be 5 long")
+ }
+ if !reflect.DeepEqual(v, result) {
+ t.Fatalf("Result is not as expected: got %v want %v", v, result)
+ }
+}
+
+func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ assertNoError(t, err, "Failed to make DB")
+ limit := uint(4)
+
+ // adding a string ID should result in rows.Scan returning an error
+ r := mock.NewRows([]string{"id"}).
+ AddRow("hej").
+ AddRow(2).
+ AddRow(3)
+
+ mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
+ // nolint:goconst
+ q := "SELECT id WHERE id IN ($1)"
+ v := []int{-1, -2, 3}
+ iKeyIDs := make([]interface{}, len(v))
+ for i, d := range v {
+ iKeyIDs[i] = d
+ }
+
+ ctx := context.Background()
+ var result = make([]uint, 0)
+ err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
+ for rows.Next() {
+ var id uint
+ err = rows.Scan(&id)
+ if err != nil {
+ return err
+ }
+ result = append(result, id)
+ }
+ return nil
+ })
+ if err == nil {
+ t.Fatalf("Call did not return an error")
+ }
+}
+
+func assertNoError(t *testing.T, err error, msg string) {
+ t.Helper()
+ if err == nil {
+ return
+ }
+ t.Fatalf(msg)
+}
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go
index f756ef5e..2484d636 100644
--- a/serverkeyapi/storage/sqlite3/server_key_table.go
+++ b/serverkeyapi/storage/sqlite3/server_key_table.go
@@ -18,9 +18,8 @@ package sqlite3
import (
"context"
"database/sql"
- "strings"
+ "fmt"
- "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -88,48 +87,50 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
- var nameAndKeyIDs []string
+ nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
-
- query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
-
+ results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v
}
- rows, err := s.db.QueryContext(ctx, query, iKeyIDs...)
+ err := sqlutil.RunLimitedVariablesQuery(
+ ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
+ func(rows *sql.Rows) error {
+ for rows.Next() {
+ var serverName string
+ var keyID string
+ var key string
+ var validUntilTS int64
+ var expiredTS int64
+ if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
+ return fmt.Errorf("bulkSelectServerKeys: %v", err)
+ }
+ r := gomatrixserverlib.PublicKeyLookupRequest{
+ ServerName: gomatrixserverlib.ServerName(serverName),
+ KeyID: gomatrixserverlib.KeyID(keyID),
+ }
+ vk := gomatrixserverlib.VerifyKey{}
+ err := vk.Key.Decode(key)
+ if err != nil {
+ return fmt.Errorf("bulkSelectServerKeys: %v", err)
+ }
+ results[r] = gomatrixserverlib.PublicKeyLookupResult{
+ VerifyKey: vk,
+ ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
+ ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
+ }
+ }
+ return nil
+ },
+ )
+
if err != nil {
return nil, err
}
- defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
- results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
- for rows.Next() {
- var serverName string
- var keyID string
- var key string
- var validUntilTS int64
- var expiredTS int64
- if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
- return nil, err
- }
- r := gomatrixserverlib.PublicKeyLookupRequest{
- ServerName: gomatrixserverlib.ServerName(serverName),
- KeyID: gomatrixserverlib.KeyID(keyID),
- }
- vk := gomatrixserverlib.VerifyKey{}
- err = vk.Key.Decode(key)
- if err != nil {
- return nil, err
- }
- results[r] = gomatrixserverlib.PublicKeyLookupResult{
- VerifyKey: vk,
- ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
- ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
- }
- }
return results, nil
}