fix(bug): database migration does not rollback schema version on hook failure (#5908)

* add tests

* correct expect schema version

* use conventional named return

* add rollback test and more description in godoc

* refactor to use commit at the end and defer rollback

* add comment

* lint
pull/5912/head
Sorawis Nilparuk (Bo) 10 months ago committed by GitHub
parent 8b8fb28228
commit d32245a33e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -202,7 +202,7 @@ func (p *Postgres) CommitRun(ctx context.Context) error {
return nil return nil
} }
// RollbackRun rolls back a transaction. // RollbackRun rolls back a transaction. If no transaction is active, it will return nil.
func (p *Postgres) RollbackRun(ctx context.Context) error { func (p *Postgres) RollbackRun(ctx context.Context) error {
const op = "postgres.(Postgres).RollbackRun" const op = "postgres.(Postgres).RollbackRun"
defer func() { defer func() {

@ -46,6 +46,38 @@ create table foo (
assert.Equal(t, v, 1001) assert.Equal(t, v, 1001)
} }
func TestRun_Rollback(t *testing.T) {
ctx := context.Background()
p, _, _ := setup(ctx, t)
statements := bytes.NewReader([]byte(`
create table foo (
id bigint primary key,
bar text
);
`))
err := p.StartRun(ctx)
require.NoError(t, err)
err = p.EnsureVersionTable(ctx)
require.NoError(t, err)
err = p.EnsureMigrationLogTable(ctx)
require.NoError(t, err)
err = p.Run(ctx, statements, 1001, "oss")
require.NoError(t, err)
err = p.RollbackRun(ctx)
require.NoError(t, err)
v, i, err := p.CurrentState(ctx, "oss")
require.NoError(t, err)
assert.False(t, i)
assert.Equal(t, -1, v)
}
func TestRun_NoTxn(t *testing.T) { func TestRun_NoTxn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
p, _, _ := setup(ctx, t) p, _, _ := setup(ctx, t)

@ -33,7 +33,7 @@ type driver interface {
StartRun(context.Context) error StartRun(context.Context) error
// CommitRun commits a transaction, if there is an error it should rollback the transaction. // CommitRun commits a transaction, if there is an error it should rollback the transaction.
CommitRun(context.Context) error CommitRun(context.Context) error
// RollbackRun rolls back a transaction. // RollbackRun rolls back a transaction. If no transaction is active, it should return nil.
RollbackRun(context.Context) error RollbackRun(context.Context) error
// CheckHook is a hook that runs prior to a migration's statements. // CheckHook is a hook that runs prior to a migration's statements.
// It should run in the same transaction a corresponding Run call. // It should run in the same transaction a corresponding Run call.
@ -247,42 +247,31 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
const op = "schema.(Manager).runMigrations" const op = "schema.(Manager).runMigrations"
var logEntries []RepairLog var logEntries []RepairLog
var errFinal error var retErr error
if startErr := b.driver.StartRun(ctx); startErr != nil { if startErr := b.driver.StartRun(ctx); startErr != nil {
errFinal = errors.Wrap(ctx, startErr, op) return nil, errors.Wrap(ctx, startErr, op)
return nil, errFinal
} }
defer func() { defer func() {
if errFinal != nil { // rolling back a committed run is a no-op, so we can safely call this every time
errRollback := b.driver.RollbackRun(ctx) if err := b.driver.RollbackRun(ctx); err != nil {
if errRollback != nil { retErr = stderrors.Join(retErr, err)
errFinal = stderrors.Join(errFinal, errRollback)
}
errFinal = errors.Wrap(ctx, errFinal, op)
return
}
if commitErr := b.driver.CommitRun(ctx); commitErr != nil {
errFinal = errors.Wrap(ctx, commitErr, op)
} }
}() }()
if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil { if ensureErr := b.driver.EnsureVersionTable(ctx); ensureErr != nil {
errFinal = errors.Wrap(ctx, ensureErr, op) return nil, errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
} }
if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil { if ensureErr := b.driver.EnsureMigrationLogTable(ctx); ensureErr != nil {
errFinal = errors.Wrap(ctx, ensureErr, op) return nil, errors.Wrap(ctx, ensureErr, op)
return nil, errFinal
} }
for p.Next() { for p.Next() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
errFinal = errors.Wrap(ctx, ctx.Err(), op) return nil, errors.Wrap(ctx, ctx.Err(), op)
return nil, errFinal
default: default:
// context is not done yet. Continue on to the next query to execute. // context is not done yet. Continue on to the next query to execute.
} }
@ -290,25 +279,22 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
if h := p.PreHook(); h != nil { if h := p.PreHook(); h != nil {
problems, err := b.driver.CheckHook(ctx, h.CheckFunc) problems, err := b.driver.CheckHook(ctx, h.CheckFunc)
if err != nil { if err != nil {
errFinal = errors.Wrap(ctx, err, op) return nil, errors.Wrap(ctx, err, op)
return nil, errFinal
} }
if len(problems) > 0 { if len(problems) > 0 {
if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) { if !b.selectedRepairs.IsSet(p.Edition(), p.Version()) {
errFinal = MigrationCheckError{ return nil, MigrationCheckError{
Version: p.Version(), Version: p.Version(),
Edition: p.Edition(), Edition: p.Edition(),
Problems: problems, Problems: problems,
RepairDescription: h.RepairDescription, RepairDescription: h.RepairDescription,
} }
return nil, errFinal
} }
repairs, err := b.driver.RepairHook(ctx, h.RepairFunc) repairs, err := b.driver.RepairHook(ctx, h.RepairFunc)
if err != nil { if err != nil {
errFinal = errors.Wrap(ctx, err, op) return nil, errors.Wrap(ctx, err, op)
return nil, errFinal
} }
logEntries = append(logEntries, RepairLog{ logEntries = append(logEntries, RepairLog{
Version: p.Version(), Version: p.Version(),
@ -318,10 +304,13 @@ func (b *Manager) runMigrations(ctx context.Context, p *provider.Provider) ([]Re
} }
} }
if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil { if runErr := b.driver.Run(ctx, bytes.NewReader(p.Statements()), p.Version(), p.Edition()); runErr != nil {
errFinal = errors.Wrap(ctx, runErr, op) return nil, errors.Wrap(ctx, runErr, op)
return nil, errFinal
} }
} }
return logEntries, nil if err := b.driver.CommitRun(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return logEntries, retErr
} }

@ -360,8 +360,8 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{ Editions: []schema.EditionState{
{ {
Name: "hooks", Name: "hooks",
BinarySchemaVersion: 1001, BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1001, DatabaseSchemaVersion: 2001,
DatabaseSchemaState: schema.Equal, DatabaseSchemaState: schema.Equal,
}, },
}, },
@ -379,7 +379,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
0, 0,
edition.WithPreHooks( edition.WithPreHooks(
map[int]*migration.Hook{ map[int]*migration.Hook{
1001: { 2001: {
CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) { CheckFunc: func(ctx context.Context, tx *sql.Tx) (migration.Problems, error) {
return migration.Problems{"failed"}, nil return migration.Problems{"failed"}, nil
}, },
@ -393,7 +393,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
}, },
nil, nil,
schema.MigrationCheckError{ schema.MigrationCheckError{
Version: 1001, Version: 2001,
Edition: "hooks", Edition: "hooks",
Problems: migration.Problems{"failed"}, Problems: migration.Problems{"failed"},
RepairDescription: "repair all the things", RepairDescription: "repair all the things",
@ -403,7 +403,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{ Editions: []schema.EditionState{
{ {
Name: "hooks", Name: "hooks",
BinarySchemaVersion: 1001, BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1, DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind, DatabaseSchemaState: schema.Behind,
}, },
@ -447,8 +447,8 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{ Editions: []schema.EditionState{
{ {
Name: "hooks", Name: "hooks",
BinarySchemaVersion: 1001, BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1001, DatabaseSchemaVersion: 2001,
DatabaseSchemaState: schema.Equal, DatabaseSchemaState: schema.Equal,
}, },
}, },
@ -494,7 +494,7 @@ func TestApplyMigrationWithHooks(t *testing.T) {
Editions: []schema.EditionState{ Editions: []schema.EditionState{
{ {
Name: "hooks", Name: "hooks",
BinarySchemaVersion: 1001, BinarySchemaVersion: 2001,
DatabaseSchemaVersion: 1, DatabaseSchemaVersion: 1,
DatabaseSchemaState: schema.Behind, DatabaseSchemaState: schema.Behind,
}, },

@ -0,0 +1,8 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1
begin;
create table test_five (
id tt_public_id primary key
);
commit;
Loading…
Cancel
Save