diff --git a/internal/db/schema/internal/postgres/postgres.go b/internal/db/schema/internal/postgres/postgres.go index 2c42b4cc58..bf918c87fa 100644 --- a/internal/db/schema/internal/postgres/postgres.go +++ b/internal/db/schema/internal/postgres/postgres.go @@ -202,6 +202,24 @@ func (p *Postgres) CommitRun(ctx context.Context) error { return nil } +// RollbackRun rolls back a transaction. +func (p *Postgres) RollbackRun(ctx context.Context) error { + const op = "postgres.(Postgres).RollbackRun" + defer func() { + p.tx = nil + }() + if p.tx == nil { + return errors.New(ctx, errors.MigrationIntegrity, op, "no pending transaction") + } + if err := p.tx.Rollback(); err != nil { + if errors.Is(err, sql.ErrTxDone) { + return nil + } + return errors.Wrap(ctx, err, op) + } + return nil +} + // Run will apply a migration. The io.Reader should provide the SQL // statements to execute, and the int is the version for that set of // statements. This should always be wrapped by StartRun and CommitRun. diff --git a/internal/db/schema/manager.go b/internal/db/schema/manager.go index fd87ba4e22..a907ff2079 100644 --- a/internal/db/schema/manager.go +++ b/internal/db/schema/manager.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "database/sql" + stderrors "errors" "fmt" "io" "sync" @@ -32,6 +33,8 @@ type driver interface { StartRun(context.Context) error // CommitRun commits a transaction, if there is an error it should rollback the transaction. CommitRun(context.Context) error + // RollbackRun rolls back a transaction. + RollbackRun(context.Context) error // CheckHook is a hook that runs prior to a migration's statements. // It should run in the same transaction a corresponding Run call. CheckHook(context.Context, migration.CheckFunc) (migration.Problems, error) @@ -244,34 +247,42 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re const op = "schema.(Manager).runMigrations" var logEntries []RepairLog - var err error + var errFinal error if startErr := b.driver.StartRun(ctx); startErr != nil { - err = errors.Wrap(ctx, startErr, op) - return nil, err + errFinal = errors.Wrap(ctx, startErr, op) + return nil, errFinal } defer func() { + if errFinal != nil { + errRollback := b.driver.RollbackRun(ctx) + if errRollback != nil { + errFinal = stderrors.Join(errFinal, errRollback) + } + errFinal = errors.Wrap(ctx, errFinal, op) + return + } if commitErr := b.driver.CommitRun(ctx); commitErr != nil { - err = errors.Wrap(ctx, commitErr, op) + errFinal = errors.Wrap(ctx, commitErr, op) } }() if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil { - err = errors.Wrap(ctx, ensureErr, op) - return nil, err + errFinal = errors.Wrap(ctx, ensureErr, op) + return nil, errFinal } if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil { - err = errors.Wrap(ctx, ensureErr, op) - return nil, err + errFinal = errors.Wrap(ctx, ensureErr, op) + return nil, errFinal } for p.Next() { select { case <-ctx.Done(): - err = errors.Wrap(ctx, ctx.Err(), op) - return nil, err + errFinal = errors.Wrap(ctx, ctx.Err(), op) + return nil, errFinal default: // context is not done yet. Continue on to the next query to execute. } @@ -279,22 +290,25 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re if h := p.PreHook(); h != nil { problems, err := b.driver.CheckHook(ctx, h.CheckFunc) if err != nil { - return nil, errors.Wrap(ctx, err, op) + errFinal = errors.Wrap(ctx, err, op) + return nil, errFinal } if len(problems) > 0 { if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) { - return nil, MigrationCheckError{ + errFinal = MigrationCheckError{ Version: p.Version(), Edition: p.Edition(), Problems: problems, RepairDescription: h.RepairDescription, } + return nil, errFinal } repairs, err := b.driver.RepairHook(ctx, h.RepairFunc) if err != nil { - return nil, errors.Wrap(ctx, err, op) + errFinal = errors.Wrap(ctx, err, op) + return nil, errFinal } logEntries = append(logEntries, RepairLog{ Version: p.Version(), @@ -304,8 +318,8 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re } } if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil { - err = errors.Wrap(ctx, runErr, op) - return nil, err + errFinal = errors.Wrap(ctx, runErr, op) + return nil, errFinal } } diff --git a/internal/db/schema/migrations/oss/postgres_97_01_test.go b/internal/db/schema/migrations/oss/postgres_97_01_test.go index 76d30fdffe..e82ffb5a09 100644 --- a/internal/db/schema/migrations/oss/postgres_97_01_test.go +++ b/internal/db/schema/migrations/oss/postgres_97_01_test.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/common" "github.com/hashicorp/boundary/internal/db/schema" + "github.com/hashicorp/boundary/internal/db/schema/internal/postgres" "github.com/hashicorp/boundary/internal/db/schema/migration" "github.com/hashicorp/boundary/internal/db/schema/migrations/oss/internal/hook97001" "github.com/hashicorp/boundary/testing/dbtest" @@ -49,7 +50,8 @@ import ( // - this cannot be created func TestMigrationHook97001(t *testing.T) { const ( - priorMigration = 96001 + priorMigration = 95001 + latestMigration = 97005 ) dialect := dbtest.Postgres ctx := context.Background() @@ -159,6 +161,19 @@ func TestMigrationHook97001(t *testing.T) { _, err = rw.Exec(ctx, query, nil) require.NoError(t, err) + // migrate to latest - make sure it fails + // migration to the prior migration (before the one we want to test) + latestm, err := schema.NewManager(ctx, schema.Dialect(dialect), d) + require.NoError(t, err) + _, err = latestm.ApplyMigrations(ctx) + require.Error(t, err) + + driver, err := postgres.New(ctx, d) + require.NoError(t, err) + schemaVer, _, err := driver.CurrentState(ctx, "oss") + require.NoError(t, err) + require.Equal(t, priorMigration, schemaVer) + tx, err := d.BeginTx(ctx, nil) require.NoError(t, err)