aboutsummaryrefslogtreecommitdiff
path: root/internal/sqlutil/migrate.go
blob: 735fb4927cb1ba70827b85e5823cf7ce8e15b72a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlutil

import (
	"context"
	"database/sql"
	"fmt"
	"sync"
	"time"

	"github.com/sirupsen/logrus"

	"github.com/matrix-org/dendrite/internal"
)

// SQL used to maintain the db_migrations bookkeeping table. Each row records a
// migration's Version, the time it was recorded, and the Dendrite version that
// recorded it.
const (
	// createDBMigrationsSQL creates the bookkeeping table if it is missing.
	createDBMigrationsSQL = "CREATE TABLE IF NOT EXISTS db_migrations (" +
		" version TEXT PRIMARY KEY NOT NULL," +
		" time TEXT NOT NULL," +
		" dendrite_version TEXT NOT NULL" +
		");"

	// insertVersionSQL records a single executed migration.
	insertVersionSQL = "INSERT INTO db_migrations (version, time, dendrite_version)" +
		" VALUES ($1, $2, $3)"

	// selectDBMigrationsSQL lists the versions of every recorded migration.
	selectDBMigrationsSQL = "SELECT version FROM db_migrations"
)

// Migration defines a migration to be run. Migrations are registered on a
// Migrator with AddMigrations and applied by Migrator.Up.
type Migration struct {
	// Version is a simple description/name of this migration. AddMigrations uses
	// it as the de-duplication key, and once the migration has run it is the
	// value stored in the db_migrations.version column.
	Version string
	// Up defines the function to execute for an upgrade. Migrator.Up calls it
	// with the single transaction shared by all pending migrations in that run.
	Up func(ctx context.Context, txn *sql.Tx) error
	// Down defines the function to execute for a downgrade (not implemented yet).
	// Migrator.Up never calls it.
	Down func(ctx context.Context, txn *sql.Tx) error
}

// Migrator contains fields required to run migrations. Construct it with
// NewMigrator; the zero value has a nil map and a nil mutex and is not usable.
//
// Fields:
//   - db: the database migrations are applied to and recorded in.
//   - migrations: registered migrations, in the order they were added.
//   - knownMigrations: the set of Version values already present in
//     migrations, used by AddMigrations to skip duplicates.
//   - mutex: taken by AddMigrations while it mutates migrations and
//     knownMigrations.
//   - insertStmt: the INSERT into db_migrations, prepared lazily by
//     insertMigration and released by close.
type Migrator struct {
	db              *sql.DB
	migrations      []Migration
	knownMigrations map[string]struct{}
	mutex           *sync.Mutex
	insertStmt      *sql.Stmt
}

// NewMigrator creates a new DB migrator bound to db. The returned Migrator has
// no registered migrations and no prepared statements; register migrations with
// AddMigrations and apply them with Up.
func NewMigrator(db *sql.DB) *Migrator {
	mig := new(Migrator)
	mig.db = db
	mig.migrations = []Migration{}
	mig.knownMigrations = map[string]struct{}{}
	mig.mutex = new(sync.Mutex)
	return mig
}

// AddMigrations appends migrations to the list of migrations. Migrations are executed
// in the order they are added to the list. De-duplicates migrations using their Version field:
// a migration whose Version was already added (earlier in this call or in a previous call)
// is silently ignored, so the first registration of a Version wins.
func (m *Migrator) AddMigrations(migrations ...Migration) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	for idx := range migrations {
		candidate := migrations[idx]
		if _, seen := m.knownMigrations[candidate.Version]; seen {
			continue
		}
		m.knownMigrations[candidate.Version] = struct{}{}
		m.migrations = append(m.migrations, candidate)
	}
}

// Up executes all migrations in order they were added, skipping any whose
// Version is already recorded in db_migrations. It first ensures the
// db_migrations table exists (via ExecutedMigrations). It then runs each
// pending migration's Up function and records it with insertMigration. All of
// this happens inside one WithTransaction callback, so every pending migration
// shares a single txn. The first failure stops the loop and is returned from
// the callback; presumably WithTransaction then rolls the txn back — confirm
// against its implementation.
//
// The migrator's mutex is held for the whole run. This matches AddMigrations,
// so the migrations slice and the cached insert statement are never touched by
// two goroutines at once.
func (m *Migrator) Up(ctx context.Context) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// ensure there is a table for known migrations
	executedMigrations, err := m.ExecutedMigrations(ctx)
	if err != nil {
		return fmt.Errorf("unable to create/get migrations: %w", err)
	}
	// ensure we close the insert statement, as it's not needed anymore; deferred
	// after the Unlock above, so it runs first and still under the lock
	defer m.close()
	return WithTransaction(m.db, func(txn *sql.Tx) error {
		for i := range m.migrations {
			migration := m.migrations[i]
			// Skip migration if it was already executed
			if _, ok := executedMigrations[migration.Version]; ok {
				continue
			}
			logrus.Debugf("Executing database migration '%s'", migration.Version)

			if err := migration.Up(ctx, txn); err != nil {
				return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
			}
			if err := m.insertMigration(ctx, txn, migration.Version); err != nil {
				return fmt.Errorf("unable to insert executed migrations: %w", err)
			}
		}
		return nil
	})
}

// insertMigration records migrationName in db_migrations, together with the
// current time formatted as RFC 3339 and internal.VersionString().
//
// On first use the INSERT statement is prepared and cached in m.insertStmt. It
// is prepared on txn when one is given, otherwise on m.db. Every execution
// goes through TxStmtContext, so it runs on txn whenever txn is non-nil.
func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error {
	if m.insertStmt == nil {
		prepare := m.db.PrepareContext
		if txn != nil {
			prepare = txn.PrepareContext
		}
		prepared, prepErr := prepare(ctx, insertVersionSQL)
		if prepErr != nil {
			return fmt.Errorf("unable to prepare insert statement: %w", prepErr)
		}
		m.insertStmt = prepared
	}
	bound := TxStmtContext(ctx, txn, m.insertStmt)
	_, execErr := bound.ExecContext(ctx, migrationName, time.Now().Format(time.RFC3339), internal.VersionString())
	return execErr
}

// ExecutedMigrations returns a map with already executed migrations in addition to creating the
// migrations table, if it doesn't exist. The map's keys are the version values read from
// db_migrations. Any error from iterating the rows is returned alongside the versions
// scanned so far.
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
	if _, createErr := m.db.ExecContext(ctx, createDBMigrationsSQL); createErr != nil {
		return nil, fmt.Errorf("unable to create db_migrations: %w", createErr)
	}
	rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
	if err != nil {
		return nil, fmt.Errorf("unable to query db_migrations: %w", err)
	}
	defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")

	executed := map[string]struct{}{}
	for rows.Next() {
		var v string
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, fmt.Errorf("unable to scan version: %w", scanErr)
		}
		executed[v] = struct{}{}
	}
	return executed, rows.Err()
}

// InsertMigration creates the migrations table if it doesn't exist and
// inserts a migration given their name to the database.
// This should only be used when manually inserting migrations.
// It is a no-op when migrationName is already recorded. The insert runs
// directly on db, outside any transaction.
func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error {
	migrator := NewMigrator(db)
	defer migrator.close()

	known, err := migrator.ExecutedMigrations(ctx)
	if err != nil {
		return err
	}
	if _, alreadyRecorded := known[migrationName]; !alreadyRecorded {
		return migrator.insertMigration(ctx, nil, migrationName)
	}
	return nil
}

// close releases the cached insert statement, if one was prepared.
//
// It then clears m.insertStmt, so a later insertMigration call prepares a
// fresh statement instead of executing one that has already been closed. That
// matters on the txn == nil path, where the cached statement is used directly.
// Close errors are logged via internal.CloseAndLogIfError rather than returned.
// Because the field is cleared, calling close more than once is harmless.
func (m *Migrator) close() {
	if m.insertStmt == nil {
		return
	}
	internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement")
	m.insertStmt = nil
}