aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/db.go58
1 files changed, 50 insertions, 8 deletions
diff --git a/test/db.go b/test/db.go
index 674fdf5c..6412feaa 100644
--- a/test/db.go
+++ b/test/db.go
@@ -15,12 +15,16 @@
package test
import (
+ "crypto/sha256"
"database/sql"
+ "encoding/hex"
"fmt"
"os"
"os/exec"
"os/user"
"testing"
+
+ "github.com/lib/pq"
)
type DBType int
@@ -30,7 +34,7 @@ var DBTypePostgres DBType = 2
var Quiet = false
-func createLocalDB(dbName string) string {
+func createLocalDB(dbName string) {
if !Quiet {
fmt.Println("Note: tests require a postgres install accessible to the current user")
}
@@ -43,7 +47,29 @@ func createLocalDB(dbName string) string {
if err != nil && !Quiet {
fmt.Println("createLocalDB returned error:", err)
}
- return dbName
+}
+
+func createRemoteDB(t *testing.T, dbName, user, connStr string) {
+ db, err := sql.Open("postgres", connStr+" dbname=postgres")
+ if err != nil {
+ t.Fatalf("failed to open postgres conn with connstr=%s : %s", connStr, err)
+ }
+ _, err = db.Exec(fmt.Sprintf(`CREATE DATABASE %s;`, dbName))
+ if err != nil {
+ pqErr, ok := err.(*pq.Error)
+ if !ok {
+ t.Fatalf("failed to CREATE DATABASE: %s", err)
+ }
+ // we ignore duplicate database error as we expect this
+ if pqErr.Code != "42P04" {
+ t.Fatalf("failed to CREATE DATABASE with code=%s msg=%s", pqErr.Code, pqErr.Message)
+ }
+ }
+ _, err = db.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON DATABASE %s TO %s`, dbName, user))
+ if err != nil {
+ t.Fatalf("failed to GRANT: %s", err)
+ }
+ _ = db.Close()
}
func currentUser() string {
@@ -64,6 +90,7 @@ func currentUser() string {
// TODO: namespace for concurrent package tests
func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
if dbType == DBTypeSQLite {
+ // this will be made in the current working directory which namespaces concurrent package runs correctly
dbname := "dendrite_test.db"
return fmt.Sprintf("file:%s", dbname), func() {
err := os.Remove(dbname)
@@ -79,13 +106,9 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
if user == "" {
user = currentUser()
}
- dbName := os.Getenv("POSTGRES_DB")
- if dbName == "" {
- dbName = createLocalDB("dendrite_test")
- }
connStr = fmt.Sprintf(
- "user=%s dbname=%s sslmode=disable",
- user, dbName,
+ "user=%s sslmode=disable",
+ user,
)
// optional vars, used in CI
password := os.Getenv("POSTGRES_PASSWORD")
@@ -97,6 +120,25 @@ func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, clo
connStr += fmt.Sprintf(" host=%s", host)
}
+ // superuser database
+ postgresDB := os.Getenv("POSTGRES_DB")
+ // we cannot use 'dendrite_test' here else 2x concurrently running packages will try to use the same db.
+ // instead, hash the current working directory, snaffle the first 16 bytes and append that to dendrite_test
+ // and use that as the unique db name. We do this because packages are per-directory hence by hashing the
+ // working (test) directory we ensure we get a consistent hash and don't hash against concurrent packages.
+ wd, err := os.Getwd()
+ if err != nil {
+ t.Fatalf("cannot get working directory: %s", err)
+ }
+ hash := sha256.Sum256([]byte(wd))
+ dbName := fmt.Sprintf("dendrite_test_%s", hex.EncodeToString(hash[:16]))
+ if postgresDB == "" { // local server, use createdb
+ createLocalDB(dbName)
+ } else { // remote server, shell into the postgres user and CREATE DATABASE
+ createRemoteDB(t, dbName, user, connStr)
+ }
+ connStr += fmt.Sprintf(" dbname=%s", dbName)
+
return connStr, func() {
// Drop all tables on the database to get a fresh instance
db, err := sql.Open("postgres", connStr)