aboutsummaryrefslogtreecommitdiff
path: root/test/db.go
diff options
context:
space:
mode:
authorkegsay <kegan@matrix.org>2022-04-08 10:12:30 +0100
committerGitHub <noreply@github.com>2022-04-08 10:12:30 +0100
commit7499147550110d24fa3a376bd811d9dd38971629 (patch)
tree335f11802f6cd391effddae9709b014ed1a17c58 /test/db.go
parent955e6eb307c78594fe9614f6a304dc521ba28d49 (diff)
Add test infrastructure code for dendrite unit/integ tests (#2331)
* Add test infrastructure code for dendrite unit/integ tests Start re-enabling some syncapi storage tests in the process. * Linting * Add postgres service to unit tests * dendrite not syncv3 * Skip test which doesn't work * Linting * Add `jetstream.PrepareForTests` Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'test/db.go')
-rw-r--r--test/db.go127
1 files changed, 127 insertions, 0 deletions
diff --git a/test/db.go b/test/db.go
new file mode 100644
index 00000000..9deec0a8
--- /dev/null
+++ b/test/db.go
@@ -0,0 +1,127 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "database/sql"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "testing"
+)
+
+type DBType int
+
+var DBTypeSQLite DBType = 1
+var DBTypePostgres DBType = 2
+
+var Quiet = false
+
+func createLocalDB(dbName string) string {
+ if !Quiet {
+ fmt.Println("Note: tests require a postgres install accessible to the current user")
+ }
+ createDB := exec.Command("createdb", dbName)
+ if !Quiet {
+ createDB.Stdout = os.Stdout
+ createDB.Stderr = os.Stderr
+ }
+ err := createDB.Run()
+ if err != nil && !Quiet {
+ fmt.Println("createLocalDB returned error:", err)
+ }
+ return dbName
+}
+
+func currentUser() string {
+ user, err := user.Current()
+ if err != nil {
+ if !Quiet {
+ fmt.Println("cannot get current user: ", err)
+ }
+ os.Exit(2)
+ }
+ return user.Username
+}
+
+// Prepare a sqlite or postgres connection string for testing.
+// Returns the connection string to use and a close function which must be called when the test finishes.
+// Calling this function twice will return the same database, which will have data from previous tests
+// unless close() is called.
+// TODO: namespace for concurrent package tests
+func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
+ if dbType == DBTypeSQLite {
+ dbname := "dendrite_test.db"
+ return fmt.Sprintf("file:%s", dbname), func() {
+ err := os.Remove(dbname)
+ if err != nil {
+ t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
+ }
+ }
+ }
+
+ // Required vars: user and db
+ // We'll try to infer from the local env if they are missing
+ user := os.Getenv("POSTGRES_USER")
+ if user == "" {
+ user = currentUser()
+ }
+ dbName := os.Getenv("POSTGRES_DB")
+ if dbName == "" {
+ dbName = createLocalDB("dendrite_test")
+ }
+ connStr = fmt.Sprintf(
+ "user=%s dbname=%s sslmode=disable",
+ user, dbName,
+ )
+ // optional vars, used in CI
+ password := os.Getenv("POSTGRES_PASSWORD")
+ if password != "" {
+ connStr += fmt.Sprintf(" password=%s", password)
+ }
+ host := os.Getenv("POSTGRES_HOST")
+ if host != "" {
+ connStr += fmt.Sprintf(" host=%s", host)
+ }
+
+ return connStr, func() {
+ // Drop all tables on the database to get a fresh instance
+ db, err := sql.Open("postgres", connStr)
+ if err != nil {
+ t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
+ }
+ _, err = db.Exec(`DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;`)
+ if err != nil {
+ t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
+ }
+ _ = db.Close()
+ }
+}
+
+// Creates subtests with each known DBType
+func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
+ dbs := map[string]DBType{
+ "postgres": DBTypePostgres,
+ "sqlite": DBTypeSQLite,
+ }
+ for dbName, dbType := range dbs {
+ dbt := dbType
+ t.Run(dbName, func(tt *testing.T) {
+ testFn(tt, dbt)
+ })
+ }
+}