fix(bug): database migration does not rollback schema version on hook failure (#5908)

* add tests

* correct expect schema version

* use conventional named return

* add rollback test and more description in godoc

* refactor to use commit at the end and defer rollback

* add comment

* lint
pull/5912/head
Sorawis Nilparuk (Bo) 8 months ago committed by GitHub
parent 8b8fb28228
commit d32245a33e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -202,7 +202,7 @@ func (p *Postgres) CommitRun(ctx context.Context) error {
return nil
}
// RollbackRun rolls back a transaction.
// RollbackRun rolls back a transaction. If no transaction is active, it will return nil.
func (p *Postgres) RollbackRun(ctx context.Context) error {
const op = "postgres.(Postgres).RollbackRun"
defer func() {

@ -46,6 +46,38 @@ create table foo (
assert.Equal(t, v, 1001)
}
func TestRun_Rollback(t *testing.T) {
ctx := context.Background()
p, _, _ := setup(ctx, t)
statements := bytes.NewReader([]byte(`
create table foo (
id bigint primary key,
bar text
);
`))
err := p.StartRun(ctx)
require.NoError(t, err)
err = p.EnsureVersionTable(ctx)
require.NoError(t, err)
err = p.EnsureMigrationLogTable(ctx)
require.NoError(t, err)
err = p.Run(ctx, statements, 1001, "oss")
require.NoError(t, err)
err = p.RollbackRun(ctx)
require.NoError(t, err)
v, i, err := p.CurrentState(ctx, "oss")
require.NoError(t, err)
assert.False(t, i)
assert.Equal(t, -1, v)
}
func TestRun_NoTxn(t *testing.T) {
ctx := context.Background()
p, _, _ := setup(ctx, t)

@ -33,7 +33,7 @@ type driver interface {
StartRun(context.Context) error
// CommitRun commits a transaction, if there is an error it should rollback the transaction.
CommitRun(context.Context) error
// RollbackRun rolls back a transaction.
// RollbackRun rolls back a transaction. If no transaction is active, it should return nil.
RollbackRun(context.Context) error
// CheckHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call.
@ -247,42 +247,31 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
const op = "schema.(Manager).runMigrations"
var logEntries []RepairLog
var errFinal error
var retErr error
if startErr := b.driver.StartRun(ctx); startErr != nil {
errFinal = errors.Wrap(ctx, startErr, op)
return nil, errFinal
return nil, errors.Wrap(ctx, startErr, op)
}
defer func() {
if errFinal != nil {
errRollback := b.driver.RollbackRun(ctx)
if errRollback != nil {
errFinal = stderrors.Join(errFinal, errRollback)
}
errFinal = errors.Wrap(ctx, errFinal, op)
return
}
if commitErr := b.driver.CommitRun(ctx); commitErr != nil {
errFinal = errors.Wrap(ctx, commitErr, op)
// rolling back a committed run is a no-op, so we can safely call this every time
if err := b.driver.RollbackRun(ctx); err != nil {
retErr = stderrors.Join(retErr, err)
}
}()
if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil {
errFinal = errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
return nil, errors.Wrap(ctx, ensureErr, op)
}
if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil {
errFinal = errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
return nil, errors.Wrap(ctx, ensureErr, op)
}
for p.Next() {
select {
case <-ctx.Done():
errFinal = errors.Wrap(ctx, ctx.Err(), op)
return nil, errFinal
return nil, errors.Wrap(ctx, ctx.Err(), op)
default:
// context is not done yet. Continue on to the next query to execute.
}
@ -290,25 +279,22 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
if h := p.PreHook(); h != nil {
problems, err := b.driver.CheckHook(ctx, h.CheckFunc)
if err != nil {
errFinal = errors.Wrap(ctx, err, op)
return nil, errFinal
return nil, errors.Wrap(ctx, err, op)
}
if len(problems) > 0 {
if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) {
errFinal = MigrationCheckError{
return nil, MigrationCheckError{
Version: p.Version(),
Edition: p.Edition(),
Problems: problems,
RepairDescription: h.RepairDescription,
}
return nil, errFinal
}
repairs, err := b.driver.RepairHook(ctx, h.RepairFunc)
if err != nil {
errFinal = errors.Wrap(ctx, err, op)
return nil, errFinal
return nil, errors.Wrap(ctx, err, op)
}
logEntries = append(logEntries, RepairLog{
Version: p.Version(),
@ -318,10 +304,13 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
}
}
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil {
errFinal = errors.Wrap(ctx, runErr, op)
return nil, errFinal
return nil, errors.Wrap(ctx, runErr, op)
}
}
return logEntries, nil
if err := b.driver.CommitRun(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return logEntries, retErr
}

@ -360,8 +360,8 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1001,
BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 2001,
DatabaseSchemaState: schema.Equal,
},
},
@ -379,7 +379,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
0,
edition.WithPreHooks(
map[int]*migration.Hook{
1001: {
2001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return migration.Problems{"failed"}, nil
},
@ -393,7 +393,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
},
nil,
schema.MigrationCheckError{
Version: 1001,
Version: 2001,
Edition: "hooks",
Problems: migration.Problems{"failed"},
RepairDescription: "repair all the things",
@ -403,7 +403,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind,
},
@ -447,8 +447,8 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
DatabaseSchemaVersion: 1001,
BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 2001,
DatabaseSchemaState: schema.Equal,
},
},
@ -494,7 +494,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{
{
Name: "hooks",
BinarySchemaVersion: 1001,
BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind,
},

@ -0,0 +1,8 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1
begin;
create table test_five (
id tt_public_id primary key
);
commit;
Loading…
Cancel
Save