Allow rollbacks of migrations if hooks fail (#5904)

pull/5912/head
Michael Milton 10 months ago committed by GitHub
parent 1b24c268fb
commit ca899f4fff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -202,6 +202,24 @@ func (p *Postgres) CommitRun(ctx context.Context) error {
return nil
}
// RollbackRun rolls back a transaction.
func (p *Postgres) RollbackRun(ctx context.Context) error {
const op = "postgres.(Postgres).RollbackRun"
defer func() {
p.tx = nil
}()
if p.tx == nil {
return errors.New(ctx, errors.MigrationIntegrity, op, "no pending transaction")
}
if err := p.tx.Rollback(); err != nil {
if errors.Is(err, sql.ErrTxDone) {
return nil
}
return errors.Wrap(ctx, err, op)
}
return nil
}
// Run will apply a migration. The io.Reader should provide the SQL
// statements to execute, and the int is the version for that set of
// statements. This should always be wrapped by StartRun and CommitRun.

@ -7,6 +7,7 @@ import (
"bytes"
"context"
"database/sql"
stderrors "errors"
"fmt"
"io"
"sync"
@ -32,6 +33,8 @@ type driver interface {
StartRun(context.Context) error
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
CommitRun(context.Context) error
// RollbackRun rolls back a transaction.
RollbackRun(context.Context) error
// CheckHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call.
CheckHook(context.Context, migration.CheckFunc) (migration.Problems, error)
@ -244,34 +247,42 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
const op = "schema.(Manager).runMigrations"
var logEntries []RepairLog
var err error
var errFinal error
if startErr := b.driver.StartRun(ctx); startErr != nil {
err = errors.Wrap(ctx, startErr, op)
return nil, err
errFinal = errors.Wrap(ctx, startErr, op)
return nil, errFinal
}
defer func() {
if errFinal != nil {
errRollback := b.driver.RollbackRun(ctx)
if errRollback != nil {
errFinal = stderrors.Join(errFinal, errRollback)
}
errFinal = errors.Wrap(ctx, errFinal, op)
return
}
if commitErr := b.driver.CommitRun(ctx); commitErr != nil {
err = errors.Wrap(ctx, commitErr, op)
errFinal = errors.Wrap(ctx, commitErr, op)
}
}()
if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil {
err = errors.Wrap(ctx, ensureErr, op)
return nil, err
errFinal = errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
}
if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil {
err = errors.Wrap(ctx, ensureErr, op)
return nil, err
errFinal = errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
}
for p.Next() {
select {
case <-ctx.Done():
err = errors.Wrap(ctx, ctx.Err(), op)
return nil, err
errFinal = errors.Wrap(ctx, ctx.Err(), op)
return nil, errFinal
default:
// context is not done yet. Continue on to the next query to execute.
}
@ -279,22 +290,25 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
if h := p.PreHook(); h != nil {
problems, err := b.driver.CheckHook(ctx, h.CheckFunc)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
errFinal = errors.Wrap(ctx, err, op)
return nil, errFinal
}
if len(problems) > 0 {
if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) {
return nil, MigrationCheckError{
errFinal = MigrationCheckError{
Version: p.Version(),
Edition: p.Edition(),
Problems: problems,
RepairDescription: h.RepairDescription,
}
return nil, errFinal
}
repairs, err := b.driver.RepairHook(ctx, h.RepairFunc)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
errFinal = errors.Wrap(ctx, err, op)
return nil, errFinal
}
logEntries = append(logEntries, RepairLog{
Version: p.Version(),
@ -304,8 +318,8 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
}
}
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil {
err = errors.Wrap(ctx, runErr, op)
return nil, err
errFinal = errors.Wrap(ctx, runErr, op)
return nil, errFinal
}
}

@ -10,6 +10,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/db/schema/internal/postgres"
"github.com/hashicorp/boundary/internal/db/schema/migration"
"github.com/hashicorp/boundary/internal/db/schema/migrations/oss/internal/hook97001"
"github.com/hashicorp/boundary/testing/dbtest"
@ -49,7 +50,8 @@ import (
// - this cannot be created
func TestMigrationHook97001(t *testing.T) {
const (
priorMigration = 96001
priorMigration = 95001
latestMigration = 97005
)
dialect := dbtest.Postgres
ctx := context.Background()
@ -159,6 +161,19 @@ func TestMigrationHook97001(t *testing.T) {
_, err = rw.Exec(ctx, query, nil)
require.NoError(t, err)
// migrate to latest - make sure it fails
// migration to the prior migration (before the one we want to test)
latestm, err := schema.NewManager(ctx, schema.Dialect(dialect), d)
require.NoError(t, err)
_, err = latestm.ApplyMigrations(ctx)
require.Error(t, err)
driver, err := postgres.New(ctx, d)
require.NoError(t, err)
schemaVer, _, err := driver.CurrentState(ctx, "oss")
require.NoError(t, err)
require.Equal(t, priorMigration, schemaVer)
tx, err := d.BeginTx(ctx, nil)
require.NoError(t, err)

Loading…
Cancel
Save