diff --git a/internal/db/schema/internal/postgres/postgres.go b/internal/db/schema/internal/postgres/postgres.go index bf918c87fa..7351981d31 100644 --- a/internal/db/schema/internal/postgres/postgres.go +++ b/internal/db/schema/internal/postgres/postgres.go @@ -202,7 +202,7 @@ func (p *Postgres) CommitRun(ctx context.Context) error { return nil } -// RollbackRun rolls back a transaction. +// RollbackRun rolls back a transaction. If no transaction is active, it will return nil. func (p *Postgres) RollbackRun(ctx context.Context) error { const op = "postgres.(Postgres).RollbackRun" defer func() { diff --git a/internal/db/schema/internal/postgres/postgres_test.go b/internal/db/schema/internal/postgres/postgres_test.go index 9d9d20537c..985a6ab9eb 100644 --- a/internal/db/schema/internal/postgres/postgres_test.go +++ b/internal/db/schema/internal/postgres/postgres_test.go @@ -46,6 +46,38 @@ create table foo ( assert.Equal(t, v, 1001) } +func TestRun_Rollback(t *testing.T) { + ctx := context.Background() + p, _, _ := setup(ctx, t) + + statements := bytes.NewReader([]byte(` +create table foo ( + id bigint primary key, + bar text +); +`)) + + err := p.StartRun(ctx) + require.NoError(t, err) + + err = p.EnsureVersionTable(ctx) + require.NoError(t, err) + + err = p.EnsureMigrationLogTable(ctx) + require.NoError(t, err) + + err = p.Run(ctx, statements, 1001, "oss") + require.NoError(t, err) + + err = p.RollbackRun(ctx) + require.NoError(t, err) + + v, i, err := p.CurrentState(ctx, "oss") + require.NoError(t, err) + assert.False(t, i) + assert.Equal(t, -1, v) +} + func TestRun_NoTxn(t *testing.T) { ctx := context.Background() p, _, _ := setup(ctx, t) diff --git a/internal/db/schema/manager.go b/internal/db/schema/manager.go index a907ff2079..9f2426085a 100644 --- a/internal/db/schema/manager.go +++ b/internal/db/schema/manager.go @@ -33,7 +33,7 @@ type driver interface { StartRun(context.Context) error // CommitRun commits a transaction, if there is an error it should rollback the transaction. CommitRun(context.Context) error - // RollbackRun rolls back a transaction. + // RollbackRun rolls back a transaction. If no transaction is active, it should return nil. RollbackRun(context.Context) error // CheckHook is a hook that runs prior to a migration's statements. // It should run in the same transaction a corresponding Run call. @@ -247,42 +247,31 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re const op = "schema.(Manager).runMigrations" var logEntries []RepairLog - var errFinal error + var retErr error if startErr := b.driver.StartRun(ctx); startErr != nil { - errFinal = errors.Wrap(ctx, startErr, op) - return nil, errFinal + return nil, errors.Wrap(ctx, startErr, op) } defer func() { - if errFinal != nil { - errRollback := b.driver.RollbackRun(ctx) - if errRollback != nil { - errFinal = stderrors.Join(errFinal, errRollback) - } - errFinal = errors.Wrap(ctx, errFinal, op) - return - } - if commitErr := b.driver.CommitRun(ctx); commitErr != nil { - errFinal = errors.Wrap(ctx, commitErr, op) + // rolling back a committed run is a no-op, so we can safely call this every time + if err := b.driver.RollbackRun(ctx); err != nil { + retErr = stderrors.Join(retErr, err) } }() if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil { - errFinal = errors.Wrap(ctx, ensureErr, op) - return nil, errFinal + return nil, errors.Wrap(ctx, ensureErr, op) } if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil { - errFinal = errors.Wrap(ctx, ensureErr, op) - return nil, errFinal + return nil, errors.Wrap(ctx, ensureErr, op) } for p.Next() { select { case <-ctx.Done(): - errFinal = errors.Wrap(ctx, ctx.Err(), op) - return nil, errFinal + return nil, errors.Wrap(ctx, ctx.Err(), op) default: // context is not done yet. Continue on to the next query to execute. } @@ -290,25 +279,22 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re if h := p.PreHook(); h != nil { problems, err := b.driver.CheckHook(ctx, h.CheckFunc) if err != nil { - errFinal = errors.Wrap(ctx, err, op) - return nil, errFinal + return nil, errors.Wrap(ctx, err, op) } if len(problems) > 0 { if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) { - errFinal = MigrationCheckError{ + return nil, MigrationCheckError{ Version: p.Version(), Edition: p.Edition(), Problems: problems, RepairDescription: h.RepairDescription, } - return nil, errFinal } repairs, err := b.driver.RepairHook(ctx, h.RepairFunc) if err != nil { - errFinal = errors.Wrap(ctx, err, op) - return nil, errFinal + return nil, errors.Wrap(ctx, err, op) } logEntries = append(logEntries, RepairLog{ Version: p.Version(), @@ -318,10 +304,13 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re } } if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil { - errFinal = errors.Wrap(ctx, runErr, op) - return nil, errFinal + return nil, errors.Wrap(ctx, runErr, op) } } - return logEntries, nil + if err := b.driver.CommitRun(ctx); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + return logEntries, retErr } diff --git a/internal/db/schema/manager_test.go b/internal/db/schema/manager_test.go index b66544432c..a2340c6c73 100644 --- a/internal/db/schema/manager_test.go +++ b/internal/db/schema/manager_test.go @@ -360,8 +360,8 @@ func TestApplyMigrationWithHooks(t *testing.T) { Editions: []schema.EditionState{ { Name: "hooks", - BinarySchemaVersion: 1001, - DatabaseSchemaVersion: 1001, + BinarySchemaVersion: 2001, + DatabaseSchemaVersion: 2001, DatabaseSchemaState: schema.Equal, }, }, @@ -379,7 +379,7 @@ func TestApplyMigrationWithHooks(t *testing.T) { 0, edition.WithPreHooks( map[int]*migration.Hook{ - 1001: { + 2001: { CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) { return migration.Problems{"failed"}, nil }, @@ -393,7 +393,7 @@ func TestApplyMigrationWithHooks(t *testing.T) { }, nil, schema.MigrationCheckError{ - Version: 1001, + Version: 2001, Edition: "hooks", Problems: migration.Problems{"failed"}, RepairDescription: "repair all the things", @@ -403,7 +403,7 @@ func TestApplyMigrationWithHooks(t *testing.T) { Editions: []schema.EditionState{ { Name: "hooks", - BinarySchemaVersion: 1001, + BinarySchemaVersion: 2001, DatabaseSchemaVersion: 1, DatabaseSchemaState: schema.Behind, }, @@ -447,8 +447,8 @@ func TestApplyMigrationWithHooks(t *testing.T) { Editions: []schema.EditionState{ { Name: "hooks", - BinarySchemaVersion: 1001, - DatabaseSchemaVersion: 1001, + BinarySchemaVersion: 2001, + DatabaseSchemaVersion: 2001, DatabaseSchemaState: schema.Equal, }, }, @@ -494,7 +494,7 @@ func TestApplyMigrationWithHooks(t *testing.T) { Editions: []schema.EditionState{ { Name: "hooks", - BinarySchemaVersion: 1001, + BinarySchemaVersion: 2001, DatabaseSchemaVersion: 1, DatabaseSchemaState: schema.Behind, }, diff --git a/internal/db/schema/testdata/hooks/updated/2/01_another_update.up.sql b/internal/db/schema/testdata/hooks/updated/2/01_another_update.up.sql new file mode 100644 index 0000000000..efed6477ec --- /dev/null +++ b/internal/db/schema/testdata/hooks/updated/2/01_another_update.up.sql @@ -0,0 +1,8 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + create table test_five ( + id tt_public_id primary key + ); +commit;