feat(schema): Add support for migration hooks

This adds the ability to define hooks that will run prior to a schema
migration. The hooks should only be used when a migration would fail due
to data being in an invalid state, likely due to a bug in the previous
schema definition. For example if a column was missing a constraint, and
a new migration is attempting to add the constraint, there might already
be values that violate the new constraint. In such cases, we would want
the user to make a decision to address the data, and need to provide
them with details about what data is problematic in case they want to
manually address the problem. To help the user, if there is a sane way
to address the data, we should provide an option to automatically
"repair" the data, allowing the migration to continue. However, this
still should be an explicit opt-in from the user, and we need to clearly
describe what the repair option would do to the user so they understand
the consequences of opting in.

To accomplish this each migration can be provided a hook that contains a
CheckFunc, RepairFunc, and RepairDescription. These functions must
operate at the point-in-time of the migration immediately preceding the
current migration, as as such they should not use and domain packages or
repositories, instead they should use direct sql with the sql library.
The description is used to explain what the RepairFunc will do.

Hooks should be defined in `internal/db/schema/migrations/oss/oss.go`
and registered with the schema edition.

See the example in `internal/db/schema/manager_example_test.go` for more
details.
pull/2347/head
Timothy Messier 4 years ago committed by Damian Debkowski
parent 1fa79e9aa1
commit 79866a287f

@ -3,6 +3,7 @@ package database
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/db/common"
@ -16,7 +17,7 @@ import (
// We expect the database already to be initialized iff initialized is set to true.
// Returns a cleanup function which must be called even if an error is returned and
// an error code where a non-zero value indicates an error happened.
func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, initialized bool, maxOpenConns int) (func(), int) {
func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, initialized bool, maxOpenConns int, selectedRepairs schema.RepairMigrations) (func(), int) {
noop := func() {}
// This database is used to keep an exclusive lock on the database for the
// remainder of the command
@ -30,7 +31,7 @@ func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, initiali
ui.Error(fmt.Sprintf("Unable to connect to the database at %q", u))
return noop, 2
}
man, err := schema.NewManager(ctx, schema.Dialect(dialect), dBase)
man, err := schema.NewManager(ctx, schema.Dialect(dialect), dBase, schema.WithRepairMigrations(selectedRepairs))
if err != nil {
if errors.Match(errors.T(errors.MigrationLock), err) {
ui.Error("Unable to capture a lock on the database.")
@ -62,21 +63,36 @@ func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, initiali
ui.Output(base.WrapAtLength("Database has already been initialized. Please use 'boundary database migrate' for any upgrade needs."))
return unlock, -1
}
if err := man.ApplyMigrations(ctx); err != nil {
repairLogs, err := man.ApplyMigrations(ctx)
if err != nil {
ui.Error(fmt.Errorf("Error running database migrations: %w", err).Error())
if checkErr, ok := err.(schema.MigrationCheckError); ok {
ui.Error(fmt.Errorf("%s", strings.Join(checkErr.Problems, "\n")).Error())
ui.Error(fmt.Sprintf("To automatically repair, use 'boundary database migrate -repair=%s:%d'. This will: %s", checkErr.Edition, checkErr.Version, checkErr.RepairDescription))
}
return unlock, 2
}
if base.Format(ui) == "table" {
ui.Info("Migrations successfully run.")
}
migrationLogs, err := man.GetMigrationLog(ctx)
if len(repairLogs) > 0 && base.Format(ui) == "table" {
ui.Info("Migration Repair logs...")
for _, e := range repairLogs {
ui.Info(fmt.Sprintf("%s:%d:", e.Edition, e.Version))
for _, entry := range e.Entry {
ui.Info(entry)
}
}
}
logs, err := man.GetMigrationLog(ctx)
if err != nil {
ui.Error(fmt.Errorf("Error retrieving database migration logs: %w", err).Error())
return unlock, 2
}
if len(migrationLogs) > 0 && base.Format(ui) == "table" {
if len(logs) > 0 && base.Format(ui) == "table" {
ui.Info("Migration Logs...")
for _, e := range migrationLogs {
for _, e := range logs {
ui.Info(e.Entry)
}
}

@ -57,7 +57,8 @@ func TestMigrateDatabase(t *testing.T) {
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": earlyMigrationVersion}),
))
require.NoError(t, err)
require.NoError(t, man.ApplyMigrations(ctx))
_, err = man.ApplyMigrations(ctx)
require.NoError(t, err)
return u
},
expectedCode: 0,
@ -81,7 +82,8 @@ func TestMigrateDatabase(t *testing.T) {
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": earlyMigrationVersion}),
))
require.NoError(t, err)
require.NoError(t, man.ApplyMigrations(ctx))
_, err = man.ApplyMigrations(ctx)
require.NoError(t, err)
return u
},
expectedCode: -1,
@ -193,7 +195,7 @@ func TestMigrateDatabase(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
u := tc.urlProvider()
ui := cli.NewMockUi()
clean, errCode := migrateDatabase(ctx, ui, dialect, u, tc.initialized, 10)
clean, errCode := migrateDatabase(ctx, ui, dialect, u, tc.initialized, 10, nil)
clean()
assert.EqualValues(t, tc.expectedCode, errCode)
assert.Equal(t, tc.expectedOutput, ui.OutputWriter.String())
@ -218,7 +220,8 @@ func TestVerifyOplogIsEmpty(t *testing.T) {
man, err := schema.NewManager(ctx, schema.Dialect(dialect), dBase)
require.NoError(t, err)
require.NoError(t, man.ApplyMigrations(ctx))
_, err = man.ApplyMigrations(ctx)
require.NoError(t, err)
cmd := InitCommand{Server: base.NewServer(base.NewCommand(cli.NewMockUi()))}

@ -266,7 +266,7 @@ func (c *InitCommand) Run(args []string) (retCode int) {
return base.CommandUserError
}
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, false, c.DatabaseMaxOpenConnections)
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, false, c.DatabaseMaxOpenConnections, nil)
defer clean()
switch errCode {
case 0:

@ -4,10 +4,13 @@ import (
"context"
"fmt"
"os"
"strconv"
"strings"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/observability/event"
host_plugin_assets "github.com/hashicorp/boundary/plugins/host"
@ -44,11 +47,14 @@ type MigrateCommand struct {
// deferred function on the Run method.
configWrapperCleanupFunc func() error
selectedRepairs schema.RepairMigrations
flagConfig string
flagConfigKms string
flagLogLevel string
flagLogFormat string
flagMigrationUrl string
flagRepairMigrations []string
flagAllowDevMigrations bool
}
@ -117,6 +123,12 @@ func (c *MigrateCommand) Flags() *base.FlagSets {
Usage: `If set, overrides a migration URL set in config, and specifies the URL used to connect to the database for migration. This can allow different permissions for the user running initialization or migration vs. normal operation. This can refer to a file on disk (file://) from which a URL will be read; an env var (env://) from which the URL will be read; or a direct database URL.`,
})
f.StringSliceVar(&base.StringSliceVar{
Name: "repair",
Target: &c.flagRepairMigrations,
Usage: `Run the repair function for the provided migration version.`,
})
return set
}
@ -256,7 +268,15 @@ plugins {
return base.CommandUserError
}
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, true, c.Config.Controller.Database.MaxOpenConnections)
clean, errCode := migrateDatabase(
c.Context,
c.UI,
dialect,
migrationUrl,
true,
c.Config.Controller.Database.MaxOpenConnections,
c.selectedRepairs,
)
defer clean()
if errCode != 0 {
return errCode
@ -275,6 +295,24 @@ func (c *MigrateCommand) ParseFlagsAndConfig(args []string) int {
return base.CommandUserError
}
c.selectedRepairs = make(schema.RepairMigrations)
for _, r := range c.flagRepairMigrations {
parts := strings.SplitN(r, ":", 2)
if len(parts) != 2 {
c.UI.Error(fmt.Sprintf("Error parsing repair option, invalid format: %s", r))
return base.CommandUserError
}
edition := parts[0]
version, err := strconv.Atoi(parts[1])
if err != nil {
c.UI.Error(fmt.Sprintf("Error parsing repair option %s, %s", r, err.Error()))
return base.CommandUserError
}
c.selectedRepairs.Add(edition, version)
}
// Validation
switch {
case len(c.flagConfig) == 0:

@ -36,7 +36,7 @@ var editions = dialects{
// - An unsupported dialect is provided.
// - The same (dialect, name) is registered.
// - The same (dialect, priority) is registered.
func RegisterEdition(name string, dialect Dialect, fs embed.FS, priority int) {
func RegisterEdition(name string, dialect Dialect, fs embed.FS, priority int, opt ...edition.Option) {
editions.Lock()
defer editions.Unlock()
@ -61,7 +61,11 @@ func RegisterEdition(name string, dialect Dialect, fs embed.FS, priority int) {
}
}
e = append(e, edition.New(name, dialect, fs, priority))
ee, err := edition.New(name, dialect, fs, priority, opt...)
if err != nil {
panic(err.Error)
}
e = append(e, ee)
e.Sort()
editions.m[dialect] = e

@ -1,10 +1,14 @@
package schema_test
import (
"context"
"database/sql"
"embed"
"testing"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/stretchr/testify/assert"
)
@ -24,12 +28,19 @@ var (
//go:embed testdata/three
three embed.FS
//go:embed testdata/hooks/initial
hooksInitial embed.FS
//go:embed testdata/hooks/updated
hooksUpdated embed.FS
)
func TestRegisterEditionPanics(t *testing.T) {
tests := []struct {
name string
editions []testEdition
opts []edition.Option
}{
{
"unsupportedDialect",
@ -41,6 +52,7 @@ func TestRegisterEditionPanics(t *testing.T) {
0,
},
},
nil,
},
{
"duplicateName",
@ -58,6 +70,7 @@ func TestRegisterEditionPanics(t *testing.T) {
1,
},
},
nil,
},
{
"duplicatePriority",
@ -75,6 +88,29 @@ func TestRegisterEditionPanics(t *testing.T) {
0,
},
},
nil,
},
{
"hookWithNoMigration",
[]testEdition{
{
"one",
schema.Postgres,
one,
0,
},
},
[]edition.Option{
edition.WithPreHooks(
map[int]*migration.Hook{
1099: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return nil, nil
},
},
},
),
},
},
}
@ -82,7 +118,7 @@ func TestRegisterEditionPanics(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
assert.Panics(t, func() {
for _, e := range tt.editions {
schema.RegisterEdition(e.name, e.dialect, e.fs, e.priority)
schema.RegisterEdition(e.name, e.dialect, e.fs, e.priority, tt.opts...)
}
}, tt.name)
})

@ -0,0 +1,20 @@
package schema
import (
"fmt"
"github.com/hashicorp/boundary/internal/db/schema/migration"
)
// MigrationCheckError is an error returned when a migration hook check function
// reports an error.
type MigrationCheckError struct {
Version int
Edition string
Problems migration.Problems
RepairDescription string
}
func (e MigrationCheckError) Error() string {
return fmt.Sprintf("check failed for %s:%d", e.Edition, e.Version)
}

@ -10,6 +10,8 @@ import (
"sort"
"strconv"
"strings"
"github.com/hashicorp/boundary/internal/db/schema/migration"
)
// Dialect is a specific SQL language variant. This generally is the same as
@ -34,7 +36,7 @@ type Edition struct {
// The set of migrations that should be applied to a database to reach the latest version.
// This is a map of schema versions to sql.
Migrations map[int][]byte
Migrations migration.Migrations
// Priority is used to determine the order that multiple Editions should be applied.
Priority int
@ -72,13 +74,16 @@ func (e Editions) Sort() {
// 2/
// 01_add_new_table.up.sql
// 02_refactor_views.up.sql
func New(name string, dialect Dialect, m embed.FS, priority int) Edition {
func New(name string, dialect Dialect, m embed.FS, priority int, opt ...Option) (Edition, error) {
var largestSchemaVersion int
migrations := make(map[int][]byte)
migrations := make(migration.Migrations)
opts := getOpts(opt...)
prehook := opts.withPreHooks
fs.WalkDir(m, ".", func(path string, d fs.DirEntry, err error) error {
err := fs.WalkDir(m, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
panic(fmt.Sprintf("unable to process migration files: %s", err))
return fmt.Errorf("unable to process migration files: %s", err)
}
if d.IsDir() {
@ -94,12 +99,12 @@ func New(name string, dialect Dialect, m embed.FS, priority int) Edition {
verMajor, err := strconv.Atoi(verMajorDir)
if err != nil {
panic(fmt.Sprintf("migration file does not have valid major version directory: %s", path))
return fmt.Errorf("migration file does not have valid major version directory: %s", path)
}
verMinor, err := strconv.Atoi(strings.SplitN(file, "_", 2)[0])
if err != nil {
panic(fmt.Sprintf("migration file does not have valid minor version prefix: %s", path))
return fmt.Errorf("migration file does not have valid minor version prefix: %s", path)
}
fullV := (verMajor * 1000) + verMinor
@ -109,7 +114,7 @@ func New(name string, dialect Dialect, m embed.FS, priority int) Edition {
cbts, err := m.ReadFile(path)
if err != nil {
panic(fmt.Sprintf("unable to read migration file: %s", path))
return fmt.Errorf("unable to read migration file: %s", path)
}
contents := strings.TrimSpace(string(cbts))
@ -122,12 +127,27 @@ func New(name string, dialect Dialect, m embed.FS, priority int) Edition {
contents = strings.TrimSpace(contents)
if _, exists := migrations[fullV]; exists {
panic(fmt.Sprintf("migration file for version %d already exists", fullV))
return fmt.Errorf("migration file for version %d already exists", fullV)
}
migrations[fullV] = migration.Migration{
Edition: name,
Statements: []byte(contents),
Version: fullV,
PreHook: prehook[fullV],
}
migrations[fullV] = []byte(contents)
return nil
})
if err != nil {
return Edition{}, err
}
for k := range prehook {
_, ok := migrations[k]
if !ok {
return Edition{}, fmt.Errorf("prehook for version %d does not correspond with a migration", k)
}
}
return Edition{
Name: name,
@ -135,5 +155,5 @@ func New(name string, dialect Dialect, m embed.FS, priority int) Edition {
LatestVersion: largestSchemaVersion,
Migrations: migrations,
Priority: priority,
}
}, nil
}

@ -52,7 +52,8 @@ func TestNew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := edition.New(tt.name, edition.Dialect("postgres"), tt.fs, tt.priority)
e, err := edition.New(tt.name, edition.Dialect("postgres"), tt.fs, tt.priority)
assert.NoError(t, err)
assert.Equal(t, e.Name, tt.name, "Name")
assert.Equal(t, e.Dialect, edition.Dialect("postgres"), "Dialect")
assert.Equal(t, e.LatestVersion, tt.expectedVersion, "Version")
@ -78,7 +79,7 @@ var (
duplicateVersions embed.FS
)
func TestNewPanics(t *testing.T) {
func TestNewErrors(t *testing.T) {
t.Parallel()
tests := []struct {
@ -112,9 +113,8 @@ func TestNewPanics(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Panics(t, func() {
edition.New(tt.name, edition.Dialect("postgres"), tt.fs, 0)
}, tt.name)
_, err := edition.New(tt.name, edition.Dialect("postgres"), tt.fs, 0)
assert.Error(t, err)
})
}
}

@ -0,0 +1,32 @@
package edition
import "github.com/hashicorp/boundary/internal/db/schema/migration"
// getOpts - iterate the inbound Options and return a struct.
func getOpts(opt ...Option) options {
opts := getDefaultOptions()
for _, o := range opt {
o(&opts)
}
return opts
}
// Option - how Options are passed as arguments.
type Option func(*options)
// options = how options are represented
type options struct {
withPreHooks map[int]*migration.Hook
}
func getDefaultOptions() options {
return options{}
}
// WithPreHooks provides an option to specify the set of migration hooks
// for a correspondings migration.
func WithPreHooks(h map[int]*migration.Hook) Option {
return func(o *options) {
o.withPreHooks = h
}
}

@ -37,6 +37,7 @@ import (
"io/ioutil"
"github.com/hashicorp/boundary/internal/db/schema/internal/log"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/internal/db/schema/migrations"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-multierror"
@ -157,6 +158,34 @@ func (p *Postgres) StartRun(ctx context.Context) error {
return nil
}
// CheckHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction as a corresponding Run call.
func (p *Postgres) CheckHook(ctx context.Context, f migration.CheckFunc) (migration.Problems, error) {
const op = "postgres.(Postgres).CheckHook"
if p.tx == nil {
return nil, errors.New(ctx, errors.MigrationIntegrity, op, "no pending transaction")
}
if f == nil {
return nil, errors.New(ctx, errors.MigrationIntegrity, op, "no check function")
}
return f(ctx, p.tx)
}
// RepairHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call.
func (p *Postgres) RepairHook(ctx context.Context, f migration.RepairFunc) (migration.Repairs, error) {
const op = "postgres.(Postgres).RepairHook"
if p.tx == nil {
return nil, errors.New(ctx, errors.MigrationIntegrity, op, "no pending transaction")
}
if f == nil {
return nil, errors.New(ctx, errors.MigrationIntegrity, op, "no repair function")
}
return f(ctx, p.tx)
}
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
func (p *Postgres) CommitRun(ctx context.Context) error {
const op = "postgres.(Postgres).CommitRun"
@ -198,7 +227,7 @@ func (p *Postgres) Run(ctx context.Context, migration io.Reader, version int, ed
return errors.Wrap(ctx, err, op)
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if _, err := p.tx.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
var col uint

@ -7,20 +7,15 @@ import (
"sort"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
)
const nilVersion = -1
type migration struct {
version int
edition string
statements []byte
}
// Provider provides the migrations to the schema.Manager in the correct order.
type Provider struct {
pos int
migrations []migration
migrations []migration.Migration
}
// DatabaseState is a map of edition names to versions.
@ -37,7 +32,7 @@ func New(dbState DatabaseState, editions edition.Editions) *Provider {
// ensure editions in priority order
editions.Sort()
allMigrations := make([]migration, 0)
allMigrations := make([]migration.Migration, 0)
for _, e := range editions {
dbVer, ok := dbState[e.Name]
@ -45,19 +40,15 @@ func New(dbState DatabaseState, editions edition.Editions) *Provider {
dbVer = nilVersion
}
migrations := make([]migration, 0, len(e.Migrations))
for ver, statements := range e.Migrations {
migrations := make([]migration.Migration, 0, len(e.Migrations))
for ver, m := range e.Migrations {
if ver > dbVer {
migrations = append(migrations, migration{
version: ver,
edition: e.Name,
statements: statements,
})
migrations = append(migrations, m)
}
}
sort.SliceStable(migrations, func(i, j int) bool {
return migrations[i].version < migrations[j].version
return migrations[i].Version < migrations[j].Version
})
allMigrations = append(allMigrations, migrations...)
@ -80,7 +71,7 @@ func (p *Provider) Version() int {
if p.pos < 0 || p.pos >= len(p.migrations) {
return -1
}
return p.migrations[p.pos].version
return p.migrations[p.pos].Version
}
// Edition returns the edition name for the current migration.
@ -88,7 +79,7 @@ func (p *Provider) Edition() string {
if p.pos < 0 || p.pos >= len(p.migrations) {
return ""
}
return p.migrations[p.pos].edition
return p.migrations[p.pos].Edition
}
// Statements returns the sql statements name for the current migration.
@ -96,5 +87,13 @@ func (p *Provider) Statements() []byte {
if p.pos < 0 || p.pos >= len(p.migrations) {
return nil
}
return p.migrations[p.pos].statements
return p.migrations[p.pos].Statements
}
// PreHook returns the hooks that should be run prior to the current migration.
func (p *Provider) PreHook() *migration.Hook {
if p.pos < 0 || p.pos >= len(p.migrations) {
return nil
}
return p.migrations[p.pos].PreHook
}

@ -1,10 +1,13 @@
package provider_test
import (
"context"
"database/sql"
"testing"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/internal/provider"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -13,11 +16,18 @@ type expectedMigration struct {
version int
edition string
statements []byte
prehook *migration.Hook
}
type expectedMigrations []expectedMigration
func TestProvider(t *testing.T) {
testHook := &migration.Hook{
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return nil, nil
},
}
tests := []struct {
name string
editions edition.Editions
@ -30,17 +40,53 @@ func TestProvider(t *testing.T) {
edition.Edition{
Name: "one",
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`migration one`),
2: []byte(`migration two`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "one",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "one",
Version: 2,
},
},
Priority: 0,
},
},
provider.DatabaseState{"one": -1},
expectedMigrations{
{1, "one", []byte(`migration one`)},
{2, "one", []byte(`migration two`)},
{1, "one", []byte(`migration one`), nil},
{2, "one", []byte(`migration two`), nil},
},
},
{
"oneEditionNoneAppliedWithHook",
edition.Editions{
edition.Edition{
Name: "one",
LatestVersion: 2,
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "one",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "one",
Version: 2,
PreHook: testHook,
},
},
Priority: 0,
},
},
provider.DatabaseState{"one": -1},
expectedMigrations{
{1, "one", []byte(`migration one`), nil},
{2, "one", []byte(`migration two`), testHook},
},
},
{
@ -49,28 +95,40 @@ func TestProvider(t *testing.T) {
edition.Edition{
Name: "one",
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`migration one`),
2: []byte(`migration two`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "one",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "one",
Version: 2,
},
},
Priority: 0,
},
edition.Edition{
Name: "two",
LatestVersion: 1,
Migrations: map[int][]byte{
1: []byte(`migration one`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "two",
Version: 1,
},
},
Priority: 0,
Priority: 1,
},
},
provider.DatabaseState{
"one": -1,
},
expectedMigrations{
{1, "one", []byte(`migration one`)},
{2, "one", []byte(`migration two`)},
{1, "two", []byte(`migration one`)},
{1, "one", []byte(`migration one`), nil},
{2, "one", []byte(`migration two`), nil},
{1, "two", []byte(`migration one`), nil},
},
},
{
@ -79,19 +137,31 @@ func TestProvider(t *testing.T) {
edition.Edition{
Name: "one",
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`migration one`),
2: []byte(`migration two`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "one",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "one",
Version: 2,
},
},
Priority: 0,
},
edition.Edition{
Name: "two",
LatestVersion: 1,
Migrations: map[int][]byte{
1: []byte(`migration one`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "two",
Version: 1,
},
},
Priority: 0,
Priority: 1,
},
},
provider.DatabaseState{
@ -99,8 +169,8 @@ func TestProvider(t *testing.T) {
"two": -1,
},
expectedMigrations{
{2, "one", []byte(`migration two`)},
{1, "two", []byte(`migration one`)},
{2, "one", []byte(`migration two`), nil},
{1, "two", []byte(`migration one`), nil},
},
},
{
@ -109,20 +179,36 @@ func TestProvider(t *testing.T) {
edition.Edition{
Name: "one",
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`migration one`),
2: []byte(`migration two`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "one",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "one",
Version: 2,
},
},
Priority: 0,
},
edition.Edition{
Name: "two",
LatestVersion: 1,
Migrations: map[int][]byte{
1: []byte(`migration one`),
2: []byte(`migration two`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`migration one`),
Edition: "two",
Version: 1,
},
2: migration.Migration{
Statements: []byte(`migration two`),
Edition: "two",
Version: 2,
},
},
Priority: 0,
Priority: 1,
},
},
provider.DatabaseState{
@ -130,8 +216,8 @@ func TestProvider(t *testing.T) {
"two": 1,
},
expectedMigrations{
{2, "one", []byte(`migration two`)},
{2, "two", []byte(`migration two`)},
{2, "one", []byte(`migration two`), nil},
{2, "two", []byte(`migration two`), nil},
},
},
}
@ -144,15 +230,17 @@ func TestProvider(t *testing.T) {
next := p.Next()
require.True(t, next)
assert.Equal(t, p.Version(), expected.version, tt.name)
assert.Equal(t, p.Edition(), expected.edition, tt.name)
assert.Equal(t, p.Statements(), expected.statements, tt.name)
assert.Equal(t, expected.version, p.Version(), tt.name)
assert.Equal(t, expected.edition, p.Edition(), tt.name)
assert.Equal(t, expected.statements, p.Statements(), tt.name)
assert.Equal(t, expected.prehook, p.PreHook(), tt.name)
}
assert.False(t, p.Next(), tt.name)
assert.Equal(t, p.Version(), -1, tt.name)
assert.Equal(t, p.Edition(), "", tt.name)
assert.Equal(t, -1, p.Version(), tt.name)
assert.Equal(t, "", p.Edition(), tt.name)
assert.Nil(t, p.Statements(), tt.name)
assert.Nil(t, p.PreHook(), tt.name)
})
}
}

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/db/schema/internal/log"
"github.com/hashicorp/boundary/internal/db/schema/internal/postgres"
"github.com/hashicorp/boundary/internal/db/schema/internal/provider"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/internal/errors"
)
@ -25,10 +26,16 @@ type driver interface {
StartRun(context.Context) error
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
CommitRun(context.Context) error
// Run will apply a migration. The io.Reader should provide the SQL
// CheckHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call.
CheckHook(context.Context, migration.CheckFunc) (migration.Problems, error)
// RepairHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call.
RepairHook(context.Context, migration.RepairFunc) (migration.Repairs, error)
// Run will apply a migrations statements. The io.Reader should provide the SQL
// statements to execute, and the int is the version for that set of
// statements. This should always be wrapped by StartRun and CommitRun.
Run(ctx context.Context, migration io.Reader, version int, edition string) error
Run(ctx context.Context, statements io.Reader, version int, edition string) error
// CurrentState returns the state of the given edition.
// ver is the current migration version number as recorded in the database.
// A version of -1 indicates no version is set.
@ -51,10 +58,11 @@ type driver interface {
// the underlying boundary database schema.
// Manager is not thread safe.
type Manager struct {
db *sql.DB
driver driver
dialect string
editions edition.Editions
db *sql.DB
driver driver
dialect string
editions edition.Editions
selectedRepairs RepairMigrations
}
// NewManager creates a new schema manager. An error is returned
@ -65,8 +73,12 @@ func NewManager(ctx context.Context, dialect Dialect, db *sql.DB, opt ...Option)
editions.Lock()
defer editions.Unlock()
dbM := Manager{db: db, dialect: string(dialect)}
opts := getOpts(opt...)
dbM := Manager{
db: db,
dialect: string(dialect),
selectedRepairs: opts.withRepairMigrations,
}
if opts.withEditions != nil {
dbM.editions = opts.withEditions
} else {
@ -75,6 +87,7 @@ func NewManager(ctx context.Context, dialect Dialect, db *sql.DB, opt ...Option)
dbM.editions = append(dbM.editions, e)
}
}
switch dialect {
case "postgres":
var err error
@ -156,12 +169,12 @@ func (b *Manager) ExclusiveUnlock(ctx context.Context) error {
// ApplyMigrations updates the database schema to match the latest version known by
// the boundary binary. An error is not returned if the database is already at
// the most recent version.
func (b *Manager) ApplyMigrations(ctx context.Context) error {
func (b *Manager) ApplyMigrations(ctx context.Context) ([]RepairLog, error) {
const op = "schema.(Manager).ApplyMigrations"
// Capturing a lock that this session to the db already possesses is okay.
if err := b.driver.Lock(ctx); err != nil {
return errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
defer func() {
if err := b.driver.Unlock(ctx); err != nil {
@ -173,24 +186,28 @@ func (b *Manager) ApplyMigrations(ctx context.Context) error {
state, err := b.CurrentState(ctx)
if err != nil {
return errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
if err = b.runMigrations(ctx, provider.New(state.databaseState(), b.editions)); err != nil {
return errors.Wrap(ctx, err, op)
logs, err := b.runMigrations(ctx, provider.New(state.databaseState(), b.editions))
if err != nil {
return nil, err
}
return nil
return logs, nil
}
// runMigrations passes migration queries to a database driver and manages
// the version and dirty bit. Cancellation or deadline/timeout is managed
// through the passed in context.
func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) (err error) {
func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]RepairLog, error) {
const op = "schema.(Manager).runMigrations"
var logEntries []RepairLog
var err error
if startErr := b.driver.StartRun(ctx); startErr != nil {
err = errors.Wrap(ctx, startErr, op)
return err
return nil, err
}
defer func() {
@ -201,27 +218,55 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) (err
if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil {
err = errors.Wrap(ctx, ensureErr, op)
return err
return nil, err
}
if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil {
err = errors.Wrap(ctx, ensureErr, op)
return err
return nil, err
}
for p.Next() {
select {
case <-ctx.Done():
err = errors.Wrap(ctx, ctx.Err(), op)
return err
return nil, err
default:
// context is not done yet. Continue on to the next query to execute.
}
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); err != nil {
if h := p.PreHook(); h != nil {
problems, err := b.driver.CheckHook(ctx, h.CheckFunc)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
if len(problems) > 0 {
if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) {
return nil, MigrationCheckError{
Version: p.Version(),
Edition: p.Edition(),
Problems: problems,
RepairDescription: h.RepairDescription,
}
}
repairs, err := b.driver.RepairHook(ctx, h.RepairFunc)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
logEntries = append(logEntries, RepairLog{
Version: p.Version(),
Edition: p.Edition(),
Entry: repairs,
})
}
}
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil {
err = errors.Wrap(ctx, runErr, op)
return err
return nil, err
}
}
return nil
return logEntries, nil
}

@ -0,0 +1,181 @@
package schema_test
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/testing/dbtest"
)
func ExampleManager_hooks() {
ctx := context.Background()
dialect := dbtest.Postgres
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
defer c()
d, err := common.SqlOpen(dialect, u)
if err != nil {
log.Fatalf(err.Error())
}
editions := edition.Editions{
{
Name: "hooks_example",
Dialect: schema.Postgres,
Migrations: migration.Migrations{
1: migration.Migration{
Edition: "hooks_example",
Version: 1,
Statements: []byte(`
create table foo (
id bigint generated always as identity primary key,
public_id text,
name text
);
-- Not a normal thing to have in a migration
-- but this is done to put "invalid" data
-- into a table, that will then have a constraint added
-- in a future migration.
insert into foo
(public_id, name)
values
(null, 'Alice'),
(null, 'Bob'),
('foo_cathy', 'Cathy');
`),
},
2: migration.Migration{
Edition: "hooks_example",
Version: 2,
Statements: []byte(`
-- this would fail if data is not updated first
alter table foo
alter column public_id
set not null;
`),
PreHook: &migration.Hook{
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
rows, err := tx.QueryContext(
ctx,
`select
id, name
from foo
where
public_id is null`,
)
if err != nil {
return nil, err
}
invalid := make([]string, 0)
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
invalid = append(invalid, fmt.Sprintf("%d:%s", id, name))
}
if len(invalid) > 0 {
return append([]string{"invalid foos:"}, invalid...), nil
}
return nil, nil
},
RepairFunc: func(ctx context.Context, tx *sql.Tx) (migration.Repairs, error) {
rows, err := tx.QueryContext(
ctx,
`delete
from foo
where
public_id is null
returning
id, name;
`,
)
if err != nil {
return nil, err
}
invalid := make([]string, 0)
for rows.Next() {
var id int
var name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
invalid = append(invalid, fmt.Sprintf("%d:%s", id, name))
}
if len(invalid) > 0 {
return append([]string{"deleted foos:"}, invalid...), nil
}
return nil, nil
},
RepairDescription: "will delete any foo that has a null public_id",
},
},
},
Priority: 0,
},
}
// Run manager with marking any migrations for repair.
// The check function in the hook should detect a problem and
// fail the migration.
m, err := schema.NewManager(
ctx,
schema.Dialect(dialect),
d,
schema.WithEditions(editions),
)
if err != nil {
log.Fatalf(err.Error())
}
_, err = m.ApplyMigrations(ctx)
checkErr, _ := err.(schema.MigrationCheckError)
fmt.Println(checkErr.Error())
fmt.Println(strings.Join(checkErr.Problems, "\n"))
fmt.Printf("repair: %s\n", checkErr.RepairDescription)
// Now run with the migration marked for repair.
// The repair function should run, delete data, and the migration
// will succeed.
m, err = schema.NewManager(
ctx,
schema.Dialect(dialect),
d,
schema.WithEditions(editions),
schema.WithRepairMigrations(schema.RepairMigrations{
"hooks_example": map[int]bool{
2: true,
},
}),
)
logs, err := m.ApplyMigrations(ctx)
if err != nil {
log.Fatalf(err.Error())
}
for _, log := range logs {
fmt.Printf("%s:%d:\n", log.Edition, log.Version)
fmt.Println(strings.Join(log.Entry, "\n"))
}
// Output: check failed for hooks_example:2
// invalid foos:
// 1:Alice
// 2:Bob
// repair: will delete any foo that has a null public_id
// hooks_example:2:
// deleted foos:
// 1:Alice
// 2:Bob
}

@ -2,12 +2,15 @@ package schema_test
import (
"context"
"database/sql"
"fmt"
"sort"
"testing"
"github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/testing/dbtest"
@ -57,8 +60,12 @@ func TestCurrentState(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 2,
Migrations: map[int][]byte{
2: []byte(`select 1`),
Migrations: migration.Migrations{
2: migration.Migration{
Statements: []byte(`select 1`),
Edition: "oss",
Version: 2,
},
},
Priority: 0,
},
@ -79,7 +86,8 @@ func TestCurrentState(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, want, s)
assert.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.NoError(t, err)
want = &schema.State{
Initialized: true,
@ -108,7 +116,10 @@ func TestApplyMigration(t *testing.T) {
{
"oneEdition",
edition.Editions{
edition.New("one", schema.Postgres, one, 0),
func() edition.Edition {
e, _ := edition.New("one", schema.Postgres, one, 0)
return e
}(),
},
false,
&schema.State{
@ -126,8 +137,14 @@ func TestApplyMigration(t *testing.T) {
{
"twoEditions",
edition.Editions{
edition.New("one", schema.Postgres, one, 0),
edition.New("two", schema.Postgres, two, 1),
func() edition.Edition {
e, _ := edition.New("one", schema.Postgres, one, 0)
return e
}(),
func() edition.Edition {
e, _ := edition.New("two", schema.Postgres, two, 1)
return e
}(),
},
false,
&schema.State{
@ -151,8 +168,14 @@ func TestApplyMigration(t *testing.T) {
{
"twoEditionsIncorrectPriority",
edition.Editions{
edition.New("one", schema.Postgres, one, 1),
edition.New("two", schema.Postgres, two, 0),
func() edition.Edition {
e, _ := edition.New("one", schema.Postgres, one, 1)
return e
}(),
func() edition.Edition {
e, _ := edition.New("two", schema.Postgres, two, 0)
return e
}(),
},
true,
&schema.State{
@ -176,9 +199,18 @@ func TestApplyMigration(t *testing.T) {
{
"threeEditions",
edition.Editions{
edition.New("one", schema.Postgres, one, 0),
edition.New("two", schema.Postgres, two, 1),
edition.New("three", schema.Postgres, three, 2),
func() edition.Edition {
e, _ := edition.New("one", schema.Postgres, one, 0)
return e
}(),
func() edition.Edition {
e, _ := edition.New("two", schema.Postgres, two, 1)
return e
}(),
func() edition.Edition {
e, _ := edition.New("three", schema.Postgres, three, 2)
return e
}(),
},
false,
&schema.State{
@ -208,9 +240,18 @@ func TestApplyMigration(t *testing.T) {
{
"threeEditionsIncorrectPriority",
edition.Editions{
edition.New("one", schema.Postgres, one, 0),
edition.New("two", schema.Postgres, two, 2),
edition.New("three", schema.Postgres, three, 1),
func() edition.Edition {
e, _ := edition.New("one", schema.Postgres, one, 0)
return e
}(),
func() edition.Edition {
e, _ := edition.New("two", schema.Postgres, two, 2)
return e
}(),
func() edition.Edition {
e, _ := edition.New("three", schema.Postgres, three, 1)
return e
}(),
},
true,
&schema.State{
@ -257,9 +298,11 @@ func TestApplyMigration(t *testing.T) {
m, err := schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions(tt.editions))
require.NoError(t, err)
if tt.expectErr {
assert.Error(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.Error(t, err)
} else {
assert.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.NoError(t, err)
}
s, err := m.CurrentState(ctx)
@ -276,6 +319,262 @@ func TestApplyMigration(t *testing.T) {
}
}
func TestApplyMigrationWithHooks(t *testing.T) {
tests := []struct {
name string
editions edition.Editions
repairs schema.RepairMigrations
expectErr error
state *schema.State
repairLogs []schema.RepairLog
}{
{
"checkPass",
edition.Editions{
func() edition.Edition {
e, _ := edition.New(
"hooks",
schema.Postgres,
hooksUpdated,
0,
edition.WithPreHooks(
map[int]*migration.Hook{
1001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return nil, nil
},
},
},
),
)
return e
}(),
},
nil,
nil,
&schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1001,
DatabaseSchemaState: schema.Equal,
},
},
},
nil,
},
{
"checkFailure",
edition.Editions{
func() edition.Edition {
e, _ := edition.New(
"hooks",
schema.Postgres,
hooksUpdated,
0,
edition.WithPreHooks(
map[int]*migration.Hook{
1001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return migration.Problems{"failed"}, nil
},
RepairDescription: "repair all the things",
},
},
),
)
return e
}(),
},
nil,
schema.MigrationCheckError{
Version: 1001,
Edition: "hooks",
Problems: migration.Problems{"failed"},
RepairDescription: "repair all the things",
},
&schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind,
},
},
},
nil,
},
{
"repair",
edition.Editions{
func() edition.Edition {
e, _ := edition.New(
"hooks",
schema.Postgres,
hooksUpdated,
0,
edition.WithPreHooks(
map[int]*migration.Hook{
1001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return migration.Problems{"failed"}, nil
},
RepairFunc: func(ctx context.Context, tx *sql.Tx) (migration.Repairs, error) {
return migration.Repairs{"repaired all the things"}, nil
},
},
},
),
)
return e
}(),
},
schema.RepairMigrations{
"hooks": map[int]bool{
1001: true,
},
},
nil,
&schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1001,
DatabaseSchemaState: schema.Equal,
},
},
},
[]schema.RepairLog{
{
Edition: "hooks",
Version: 1001,
Entry: migration.Repairs{"repaired all the things"},
},
},
},
{
"repairRequestNoRepairFunc",
edition.Editions{
func() edition.Edition {
e, _ := edition.New(
"hooks",
schema.Postgres,
hooksUpdated,
0,
edition.WithPreHooks(
map[int]*migration.Hook{
1001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return migration.Problems{"failed"}, nil
},
},
},
),
)
return e
}(),
},
schema.RepairMigrations{
"hooks": map[int]bool{
1001: true,
},
},
fmt.Errorf("schema.(Manager).runMigrations: postgres.(Postgres).RepairHook: no repair function: integrity violation: error #2000"),
&schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind,
},
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dialect := dbtest.Postgres
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
t.Cleanup(func() {
if err := c(); err != nil {
t.Fatalf("Got error at cleanup: %v", err)
}
})
require.NoError(t, err)
ctx := context.Background()
d, err := common.SqlOpen(dialect, u)
require.NoError(t, err)
m, err := schema.NewManager(
ctx,
schema.Dialect(dialect),
d,
schema.WithEditions(
edition.Editions{
func() edition.Edition {
e, _ := edition.New(
"hooks",
schema.Postgres,
hooksInitial,
0,
)
return e
}(),
},
),
)
require.NoError(t, err)
logs, err := m.ApplyMigrations(ctx)
assert.NoError(t, err)
assert.Empty(t, logs)
m, err = schema.NewManager(
ctx,
schema.Dialect(dialect),
d,
schema.WithEditions(tt.editions),
schema.WithRepairMigrations(tt.repairs),
)
require.NoError(t, err)
logs, err = m.ApplyMigrations(ctx)
if tt.expectErr != nil {
assert.EqualError(t, tt.expectErr, err.Error())
if want, ok := tt.expectErr.(schema.MigrationCheckError); ok {
got, ok := err.(schema.MigrationCheckError)
assert.True(t, ok, "not a schema.MigrationCheckError")
assert.Equal(t, want, got)
}
} else {
assert.NoError(t, err)
}
assert.ElementsMatch(t, tt.repairLogs, logs)
s, err := m.CurrentState(ctx)
require.NoError(t, err)
assert.Equal(t, tt.state.Initialized, s.Initialized)
assert.ElementsMatch(t, tt.state.Editions, s.Editions)
if tt.expectErr != nil {
assert.False(t, s.MigrationsApplied())
} else {
assert.True(t, s.MigrationsApplied())
}
})
}
}
func TestApplyMigration_canceledContext(t *testing.T) {
dialect := dbtest.Postgres
@ -293,7 +592,8 @@ func TestApplyMigration_canceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
cancel()
assert.Error(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.Error(t, err)
}
func TestApplyMigrations_BadSQL(t *testing.T) {
@ -315,15 +615,20 @@ func TestApplyMigrations_BadSQL(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 1,
Migrations: map[int][]byte{
1: []byte(`select 1 from nonexistanttable;`),
Migrations: migration.Migrations{
2: migration.Migration{
Statements: []byte(`select 1 from nonexistanttable;`),
Edition: "oss",
Version: 2,
},
},
Priority: 0,
},
},
))
require.NoError(t, err)
assert.Error(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.Error(t, err)
state, err := m.CurrentState(ctx)
require.NoError(t, err)
@ -409,7 +714,8 @@ func Test_GetMigrationLog(t *testing.T) {
require.NoError(t, err)
m, err := schema.NewManager(ctx, schema.Dialect(dialect), d)
require.NoError(t, err)
require.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(t, err)
const insert = `insert into log_migration(entry, edition) values ($1, $2)`
createEntries := func(entries ...string) {

@ -0,0 +1,44 @@
package migration
import (
"context"
"database/sql"
)
// Problems are reports of data issues that were identified by a CheckFunc.
type Problems []string
// Repairs are reports of changes made to data by a RepairFunc.
type Repairs []string
// CheckFunc is a function that checks the state of data in the database to
// determine if a migration will fail, and if so to report the data that is
// problematic so it can be fixed.
type CheckFunc func(context.Context, *sql.Tx) (Problems, error)
// RepairFunc is a function that alters data in the database to resolve issues
// that would prevent a migration from successfully running.
type RepairFunc func(context.Context, *sql.Tx) (Repairs, error)
// Hook provides a set of functions that allow for executing checks prior to
// executing migration statements.
type Hook struct {
CheckFunc CheckFunc
RepairFunc RepairFunc
// RepairDescription will describe what change running the repair function
// would perform.
RepairDescription string
}
// Migration is a set of statements that will alter the database structure or
// or data.
type Migration struct {
Statements []byte
Edition string
Version int
PreHook *Hook
}
// Migrations are a set of migrations by version.
type Migrations map[int]Migration

@ -6,12 +6,16 @@ import (
"embed"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
)
// postgres contains the migrations sql files for postgres oss edition
//go:embed postgres
var postgres embed.FS
var prehooks = map[int]*migration.Hook{}
func init() {
schema.RegisterEdition("oss", schema.Postgres, postgres, 0)
schema.RegisterEdition("oss", schema.Postgres, postgres, 0, edition.WithPreHooks(prehooks))
}

@ -28,7 +28,8 @@ func TestApplyMigrations(t *testing.T) {
ctx := context.Background()
m, err := schema.NewManager(ctx, schema.Dialect(dialect), d)
require.NoError(t, err)
assert.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.NoError(t, err)
}
func TestApplyMigrations_NotFromFresh(t *testing.T) {
@ -48,7 +49,8 @@ func TestApplyMigrations_NotFromFresh(t *testing.T) {
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": 1}),
))
require.NoError(t, err)
assert.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
assert.NoError(t, err)
state, err := m.CurrentState(ctx)
require.NoError(t, err)
@ -69,7 +71,8 @@ func TestApplyMigrations_NotFromFresh(t *testing.T) {
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": 3}),
))
require.NoError(t, err)
assert.NoError(t, newM.ApplyMigrations(ctx))
_, err = newM.ApplyMigrations(ctx)
assert.NoError(t, err)
state, err = newM.CurrentState(ctx)
require.NoError(t, err)
want = &schema.State{

@ -61,7 +61,8 @@ func Test_ServerEnumChanges(t *testing.T) {
))
require.NoError(err)
require.NoError(m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(err)
state, err := m.CurrentState(ctx)
require.NoError(err)
want := &schema.State{
@ -87,7 +88,8 @@ func Test_ServerEnumChanges(t *testing.T) {
))
require.NoError(err)
require.NoError(m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(err)
state, err = m.CurrentState(ctx)
require.NoError(err)

@ -51,7 +51,8 @@ func testSetupDb(ctx context.Context, t *testing.T) *sql.DB {
))
require.NoError(err)
require.NoError(m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(err)
state, err := m.CurrentState(ctx)
require.NoError(err)
want := &schema.State{

@ -36,7 +36,8 @@ func TestMigrations_WareHouse_HostAddresses(t *testing.T) {
))
require.NoError(t, err)
require.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(t, err)
state, err := m.CurrentState(ctx)
require.NoError(t, err)
want := &schema.State{
@ -198,7 +199,8 @@ values
))
require.NoError(t, err)
require.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(t, err)
state, err = m.CurrentState(ctx)
require.NoError(t, err)
want = &schema.State{

@ -39,7 +39,8 @@ func TestMigrations_KMS_Refactor(t *testing.T) {
))
require.NoError(t, err)
require.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(t, err)
state, err := m.CurrentState(ctx)
require.NoError(t, err)
want := &schema.State{
@ -77,7 +78,8 @@ func TestMigrations_KMS_Refactor(t *testing.T) {
))
require.NoError(t, err)
require.NoError(t, m.ApplyMigrations(ctx))
_, err = m.ApplyMigrations(ctx)
require.NoError(t, err)
state, err = m.CurrentState(ctx)
require.NoError(t, err)
want = &schema.State{

@ -16,8 +16,9 @@ type Option func(*options)
// options = how options are represented
type options struct {
withEditions edition.Editions
withDeleteLog bool
withEditions edition.Editions
withDeleteLog bool
withRepairMigrations map[string]map[int]bool
}
func getDefaultOptions() options {
@ -37,3 +38,11 @@ func WithDeleteLog(del bool) Option {
o.withDeleteLog = del
}
}
// WithRepairMigrations provides an option to specify the set of migrations
// that should run their repair functions if there is a failure on a prehook check.
func WithRepairMigrations(r RepairMigrations) Option {
return func(o *options) {
o.withRepairMigrations = r
}
}

@ -0,0 +1,36 @@
package schema
import "github.com/hashicorp/boundary/internal/db/schema/migration"
// RepairMigrations is a set of migration versions grouped by edition that
// should have their coresponding repair functions run if the check function
// reports an error.
type RepairMigrations map[string]map[int]bool
// IsSet checks for the existence of the given edition and version.
func (r RepairMigrations) IsSet(edition string, version int) bool {
e, ok := r[edition]
if !ok {
return false
}
_, ok = e[version]
return ok
}
// Add adds the edition and version to the set.
func (r RepairMigrations) Add(edition string, version int) {
e, ok := r[edition]
if !ok {
e = make(map[int]bool)
}
e[version] = true
r[edition] = e
}
// RepairLog represents a log entry generated by a repair function.
type RepairLog struct {
Edition string
Version int
Entry migration.Repairs
}

@ -0,0 +1,134 @@
package schema_test
import (
"testing"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/stretchr/testify/assert"
)
func TestRepairMigrationsIsSet(t *testing.T) {
cases := []struct {
name string
m schema.RepairMigrations
edition string
version int
want bool
}{
{
name: "Set",
m: schema.RepairMigrations{
"one": {
1: true,
},
},
edition: "one",
version: 1,
want: true,
},
{
name: "VersionNotSet",
m: schema.RepairMigrations{
"one": {
2: true,
},
},
edition: "one",
version: 1,
want: false,
},
{
name: "EditionNotSet",
m: schema.RepairMigrations{
"two": {
1: true,
},
},
edition: "one",
version: 1,
want: false,
},
{
name: "Empty",
m: schema.RepairMigrations{},
edition: "one",
version: 1,
want: false,
},
{
name: "Nil",
m: nil,
edition: "one",
version: 1,
want: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := tc.m.IsSet(tc.edition, tc.version)
assert.Equal(t, tc.want, got, tc.name)
})
}
}
func TestRepairMigrationsAdd(t *testing.T) {
cases := []struct {
name string
initial schema.RepairMigrations
edition string
version int
want schema.RepairMigrations
}{
{
name: "Empty",
initial: schema.RepairMigrations{},
edition: "one",
version: 1,
want: schema.RepairMigrations{
"one": {
1: true,
},
},
},
{
name: "AlreadySet",
initial: schema.RepairMigrations{
"one": {
1: true,
},
},
edition: "one",
version: 1,
want: schema.RepairMigrations{
"one": {
1: true,
},
},
},
{
name: "EditionExistsNewVersion",
initial: schema.RepairMigrations{
"one": {
1: true,
},
},
edition: "one",
version: 2,
want: schema.RepairMigrations{
"one": {
1: true,
2: true,
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tc.initial.Add(tc.edition, tc.version)
got := tc.initial
assert.Equal(t, tc.want, got, tc.name)
})
}
}

@ -32,7 +32,7 @@ func MigrateStore(ctx context.Context, dialect Dialect, url string, opt ...Optio
return false, nil
}
if err := sMan.ApplyMigrations(ctx); err != nil {
if _, err := sMan.ApplyMigrations(ctx); err != nil {
return false, errors.Wrap(ctx, err, op)
}

@ -6,6 +6,7 @@ import (
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -27,8 +28,12 @@ func TestMigrateStore(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 1,
Migrations: map[int][]byte{
1: []byte(`select 1`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`select 1`),
Version: 1,
Edition: "oss",
},
},
Priority: 0,
},
@ -43,8 +48,12 @@ func TestMigrateStore(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 1,
Migrations: map[int][]byte{
2: []byte(`select 1`),
Migrations: migration.Migrations{
2: migration.Migration{
Statements: []byte(`select 1`),
Version: 2,
Edition: "oss",
},
},
Priority: 0,
},
@ -59,9 +68,17 @@ func TestMigrateStore(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`select 1`),
2: []byte(`select 1`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`select 1`),
Version: 1,
Edition: "oss",
},
2: migration.Migration{
Statements: []byte(`select 1`),
Version: 2,
Edition: "oss",
},
},
Priority: 0,
},
@ -75,9 +92,17 @@ func TestMigrateStore(t *testing.T) {
Name: "oss",
Dialect: schema.Postgres,
LatestVersion: 2,
Migrations: map[int][]byte{
1: []byte(`select 1`),
2: []byte(`select 1`),
Migrations: migration.Migrations{
1: migration.Migration{
Statements: []byte(`select 1`),
Version: 1,
Edition: "oss",
},
2: migration.Migration{
Statements: []byte(`select 1`),
Version: 2,
Edition: "oss",
},
},
Priority: 0,
},

@ -0,0 +1,6 @@
begin;
create domain tt_public_id as text
check(
length(trim(value)) > 10
);
commit;

@ -0,0 +1,6 @@
begin;
create domain tt_public_id as text
check(
length(trim(value)) > 10
);
commit;

@ -0,0 +1,5 @@
begin;
create table test_four (
id tt_public_id primary key
);
commit;

@ -1,6 +1,9 @@
package schema
import "github.com/hashicorp/boundary/internal/db/schema/internal/edition"
import (
"github.com/hashicorp/boundary/internal/db/schema/internal/edition"
"github.com/hashicorp/boundary/internal/db/schema/migration"
)
// PartialEditions is used by TestCreatePartialEditions. It is a map of edition
// names to the max version that should be included.
@ -20,7 +23,7 @@ func TestCreatePartialEditions(dialect Dialect, p PartialEditions) edition.Editi
Dialect: ee.Dialect,
Priority: ee.Priority,
LatestVersion: nilVersion,
Migrations: make(map[int][]byte),
Migrations: make(migration.Migrations),
}
for k, b := range ee.Migrations {

@ -109,7 +109,8 @@ func testInitStore(t testing.TB, cleanup func() error, url string) {
require.NoError(t, err)
sm, err := schema.NewManager(ctx, schema.Dialect(dialect), d)
require.NoError(t, err)
require.NoError(t, sm.ApplyMigrations(ctx))
_, err = sm.ApplyMigrations(ctx)
require.NoError(t, err)
}
type constraintResults struct {

Loading…
Cancel
Save