aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--federationapi/federationapi_keys_test.go10
-rw-r--r--federationapi/federationapi_test.go8
-rw-r--r--internal/sqlutil/migrate.go62
-rw-r--r--internal/sqlutil/migrate_test.go42
4 files changed, 82 insertions, 40 deletions
diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go
index 4469da35..85cc43aa 100644
--- a/federationapi/federationapi_keys_test.go
+++ b/federationapi/federationapi_keys_test.go
@@ -12,12 +12,13 @@ import (
"testing"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/gomatrixserverlib"
)
type server struct {
@@ -86,7 +87,12 @@ func TestMain(m *testing.M) {
cfg.Global.JetStream.StoragePath = config.Path(d)
cfg.Global.KeyID = serverKeyID
cfg.Global.KeyValidityPeriod = s.validity
- cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
+ f, err := os.CreateTemp(d, "federation_keys_test*.db")
+ if err != nil {
+ return -1
+ }
+ defer f.Close()
+ cfg.FederationAPI.Database.ConnectionString = config.DataSource("file:" + f.Name())
s.config = &cfg.FederationAPI
// Create a transport which redirects federation requests to
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index 15f7a684..e923143a 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -10,6 +10,10 @@ import (
"testing"
"time"
+ "github.com/matrix-org/gomatrix"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal"
@@ -20,9 +24,6 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
- "github.com/matrix-org/gomatrix"
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/nats-io/nats.go"
)
type fedRoomserverAPI struct {
@@ -271,7 +272,6 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost")
cfg.Global.PrivateKey = privKey
cfg.Global.JetStream.InMemory = true
- cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
base := base.NewBaseDendrite(cfg, "Monolith")
keyRing := &test.NopJSONVerifier{}
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go
index b6a8b1f2..a66a7582 100644
--- a/internal/sqlutil/migrate.go
+++ b/internal/sqlutil/migrate.go
@@ -49,12 +49,13 @@ type Migration struct {
Down func(ctx context.Context, txn *sql.Tx) error
}
-// Migrator
+// Migrator contains fields required to run migrations.
type Migrator struct {
db *sql.DB
migrations []Migration
knownMigrations map[string]struct{}
mutex *sync.Mutex
+ insertStmt *sql.Stmt
}
// NewMigrator creates a new DB migrator.
@@ -82,35 +83,26 @@ func (m *Migrator) AddMigrations(migrations ...Migration) {
// Up executes all migrations in order they were added.
func (m *Migrator) Up(ctx context.Context) error {
- var (
- err error
- dendriteVersion = internal.VersionString()
- )
// ensure there is a table for known migrations
executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
return fmt.Errorf("unable to create/get migrations: %w", err)
}
-
+ // ensure we close the insert statement, as it's not needed anymore
+ defer m.close()
return WithTransaction(m.db, func(txn *sql.Tx) error {
for i := range m.migrations {
- now := time.Now().UTC().Format(time.RFC3339)
migration := m.migrations[i]
// Skip migration if it was already executed
if _, ok := executedMigrations[migration.Version]; ok {
continue
}
logrus.Debugf("Executing database migration '%s'", migration.Version)
- err = migration.Up(ctx, txn)
- if err != nil {
+
+ if err = migration.Up(ctx, txn); err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
- _, err = txn.ExecContext(ctx, insertVersionSQL,
- migration.Version,
- now,
- dendriteVersion,
- )
- if err != nil {
+ if err = m.insertMigration(ctx, txn, migration.Version); err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
}
}
@@ -118,6 +110,23 @@ func (m *Migrator) Up(ctx context.Context) error {
})
}
+func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error {
+ if m.insertStmt == nil {
+ stmt, err := m.db.Prepare(insertVersionSQL)
+ if err != nil {
+ return fmt.Errorf("unable to prepare insert statement: %w", err)
+ }
+ m.insertStmt = stmt
+ }
+ stmt := TxStmtContext(ctx, txn, m.insertStmt)
+ _, err := stmt.ExecContext(ctx,
+ migrationName,
+ time.Now().Format(time.RFC3339),
+ internal.VersionString(),
+ )
+ return err
+}
+
// ExecutedMigrations returns a map with already executed migrations in addition to creating the
// migrations table, if it doesn't exist.
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
@@ -146,19 +155,20 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{},
// inserts a migration given their name to the database.
// This should only be used when manually inserting migrations.
func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error {
- _, err := db.ExecContext(ctx, createDBMigrationsSQL)
+ m := NewMigrator(db)
+ defer m.close()
+ existingMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
- return fmt.Errorf("unable to create db_migrations: %w", err)
+ return err
}
- _, err = db.ExecContext(ctx, insertVersionSQL,
- migrationName,
- time.Now().Format(time.RFC3339),
- internal.VersionString(),
- )
- // If the migration was already executed, we'll get a unique constraint error,
- // return nil instead, to avoid unnecessary logging.
- if IsUniqueConstraintViolationErr(err) {
+ if _, ok := existingMigrations[migrationName]; ok {
return nil
}
- return err
+ return m.insertMigration(ctx, nil, migrationName)
+}
+
+func (m *Migrator) close() {
+ if m.insertStmt != nil {
+ internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement")
+ }
}
diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go
index d8bcae19..5116237a 100644
--- a/internal/sqlutil/migrate_test.go
+++ b/internal/sqlutil/migrate_test.go
@@ -7,9 +7,10 @@ import (
"reflect"
"testing"
+ _ "github.com/mattn/go-sqlite3"
+
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
- _ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
@@ -81,11 +82,12 @@ func Test_migrations_Up(t *testing.T) {
}
ctx := context.Background()
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- conStr, close := test.PrepareDBConnectionString(t, dbType)
- defer close()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ conStr, close := test.PrepareDBConnectionString(t, dbType)
+ defer close()
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
@@ -107,6 +109,30 @@ func Test_migrations_Up(t *testing.T) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
- })
- }
+ }
+ })
+}
+
+func Test_insertMigration(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ conStr, close := test.PrepareDBConnectionString(t, dbType)
+ defer close()
+ driverName := "sqlite3"
+ if dbType == test.DBTypePostgres {
+ driverName = "postgres"
+ }
+
+ db, err := sql.Open(driverName, conStr)
+ if err != nil {
+ t.Errorf("unable to open database: %v", err)
+ }
+
+ if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil {
+ t.Fatalf("unable to insert migration: %s", err)
+ }
+ // Second insert should not return an error, as it was already executed.
+ if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil {
+ t.Fatalf("unable to insert migration: %s", err)
+ }
+ })
}