diff options
author | Kegsay <kegan@matrix.org> | 2020-09-14 16:39:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-14 16:39:38 +0100 |
commit | 8dc95062101b3906ffb83604e2abca02d9a3dd03 (patch) | |
tree | bdd5ec08d14c113346dd7d731e51403e83e5940b /internal | |
parent | 913020e4b7ceeaf67b4c7ed8e0778126989846ef (diff) |
Don't use more than 999 variables in SQLite querys. (#1425)
* Don't use more than 999 variables in SQLite querys.
Solve this problem in a more general and reusable way.
Also fix #1369
Add some unit tests.
Signed-off-by: Henrik Sölver <henrik.solver@gmail.com>
* Don't rely on testify for basic assertions
* Readability improvements and linting
Co-authored-by: Henrik Sölver <henrik.solver@gmail.com>
Diffstat (limited to 'internal')
-rw-r--r-- | internal/sqlutil/sql.go | 45 | ||||
-rw-r--r-- | internal/sqlutil/sqlutil_test.go | 173 |
2 files changed, 218 insertions, 0 deletions
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 1d2825d5..90562ded 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -15,10 +15,14 @@ package sqlutil import ( + "context" "database/sql" "errors" "fmt" "runtime" + "strings" + + "github.com/matrix-org/util" ) // ErrUserExists is returned if a username already exists in the database. @@ -107,3 +111,44 @@ func SQLiteDriverName() string { } return "sqlite3" } + +func minOfInts(a, b int) int { + if a <= b { + return a + } + return b +} + +// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery. +type QueryProvider interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement +// SQLlite can handle. See https://www.sqlite.org/limits.html for more information. +const SQLite3MaxVariables = 999 + +// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error") + return err + } + err = rowHandler(rows) + if closeErr := rows.Close(); closeErr != nil { + util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows") + return err + } + if err != nil { + util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error") + return err + } + start = start + n + } + return nil +} diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go new file mode 100644 index 00000000..79469cdd --- /dev/null +++ b/internal/sqlutil/sqlutil_test.go @@ -0,0 +1,173 @@ +package sqlutil + +import ( + "context" + "database/sql" + "reflect" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" +) + +func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 3 long") + } +} + +func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 4 long") + } +} + +func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r1 := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + r2 := mock.NewRows([]string{"id"}). + AddRow(5) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4, 5} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 5 long") + } + if !reflect.DeepEqual(v, result) { + t.Fatalf("Result is not as expected: got %v want %v", v, result) + } +} + +func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + // adding a string ID should result in rows.Scan returning an error + r := mock.NewRows([]string{"id"}). + AddRow("hej"). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{-1, -2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]uint, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id uint + err = rows.Scan(&id) + if err != nil { + return err + } + result = append(result, id) + } + return nil + }) + if err == nil { + t.Fatalf("Call did not return an error") + } +} + +func assertNoError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { + return + } + t.Fatalf(msg) +} |