Support upgrading from old migration table name (#889)

pull/893/head
Todd Knight 5 years ago committed by GitHub
parent 3461dcbd0c
commit 955ae41215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@ import (
// It owns the reporting to the UI any errors.
// Returns a cleanup function which must be called even if an error is returned and
// an error code where a non-zero value indicates an error happened.
func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string) (func(), int) {
func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, requireFresh bool) (func(), int) {
noop := func() {}
// This database is used to keep an exclusive lock on the database for the
// remainder of the command
@ -52,6 +52,10 @@ func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string) (func(),
ui.Error(fmt.Errorf("Error getting database state: %w", err).Error())
return unlock, 1
}
if requireFresh && st.InitializationStarted {
ui.Error(base.WrapAtLength("Database has already been initialized. Please use 'boundary database migrate'."))
return unlock, 1
}
if st.Dirty {
ui.Error(base.WrapAtLength("Database is in a bad state. Please revert back to the last known good state."))
return unlock, 1

@ -19,6 +19,7 @@ func TestMigrateDatabase(t *testing.T) {
cases := []struct {
name string
requireFresh bool
urlProvider func() string
expectedCode int
expectedOutput string
@ -37,6 +38,25 @@ func TestMigrateDatabase(t *testing.T) {
expectedCode: 0,
expectedOutput: "Migrations successfully run.\n",
},
{
name: "old_version_table_used",
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
dBase, err := sql.Open(dialect, u)
require.NoError(t, err)
createStmt := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = dBase.Exec(createStmt)
require.NoError(t, err)
return u
},
expectedCode: 0,
expectedOutput: "Migrations successfully run.\n",
},
{
name: "bad_url",
urlProvider: func() string { return "badurl" },
@ -65,13 +85,77 @@ func TestMigrateDatabase(t *testing.T) {
expectedCode: 1,
expectedError: "Unable to capture a lock on the database.\n",
},
{
name: "basic_require_fresh",
requireFresh: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
return u
},
expectedCode: 0,
expectedOutput: "Migrations successfully run.\n",
},
{
name: "old_version_table_used_require_fresh",
requireFresh: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
dBase, err := sql.Open(dialect, u)
require.NoError(t, err)
createStmt := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = dBase.Exec(createStmt)
require.NoError(t, err)
return u
},
expectedCode: 1,
expectedError: "Database has already been initialized. Please use 'boundary database\nmigrate'.\n",
},
{
name: "bad_url_require_fresh",
requireFresh: true,
urlProvider: func() string { return "badurl" },
expectedCode: 1,
expectedError: "Unable to connect to the database at \"badurl\"\n",
},
{
name: "cant_get_lock_require_fresh",
requireFresh: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
dBase, err := sql.Open(dialect, u)
require.NoError(t, err)
man, err := schema.NewManager(ctx, dialect, dBase)
require.NoError(t, err)
// This is an advisory lock on the DB which is released when the DB session ends.
err = man.ExclusiveLock(ctx)
require.NoError(t, err)
return u
},
expectedCode: 1,
expectedError: "Unable to capture a lock on the database.\n",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
u := tc.urlProvider()
ui := cli.NewMockUi()
clean, errCode := migrateDatabase(ctx, ui, dialect, u)
clean, errCode := migrateDatabase(ctx, ui, dialect, u, tc.requireFresh)
clean()
assert.EqualValues(t, tc.expectedCode, errCode)
assert.Equal(t, tc.expectedOutput, ui.OutputWriter.String())

@ -260,7 +260,7 @@ func (c *InitCommand) Run(args []string) (retCode int) {
return 1
}
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl)
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, true)
defer clean()
if errCode != 0 {
return errCode

@ -200,7 +200,7 @@ func (c *MigrateCommand) Run(args []string) (retCode int) {
return 1
}
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl)
clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, false)
defer clean()
if errCode != 0 {
return errCode

@ -367,6 +367,9 @@ func (c *Command) Run(args []string) int {
c.UI.Error(fmt.Errorf("Error checking schema state: %w", err).Error())
return 1
}
if !ckState.InitializationStarted {
c.UI.Error(base.WrapAtLength("The database has not been initialized. Please run 'boundary database init'."))
}
if ckState.Dirty {
c.UI.Error(base.WrapAtLength("Database is in a bad state. Please revert the database into the last known good state."))
return 1

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/internal/db/schema/postgres"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-multierror"
)
// driver provides functionality to a database.
@ -29,7 +30,8 @@ type driver interface {
// executing Run.
Run(context.Context, io.Reader, int) error
// A version of -1 indicates no version is set.
CurrentState(context.Context) (int, bool, error)
CurrentState(context.Context) (ver int, everRan bool, dirty bool, err error)
EnsureVersionTable(ctx context.Context) error
}
// Manager provides a way to run operations and retrieve information regarding
@ -61,8 +63,7 @@ func NewManager(ctx context.Context, dialect string, db *sql.DB) (*Manager, erro
// State contains information regarding the current state of a boundary database's schema.
type State struct {
// InitializationStarted indicates if the current database has already been initialized
// (successfully or not) at least once.
// InitializationStarted indicates if the current database has been initialized previously.
InitializationStarted bool
// Dirty is set to true if the database failed in a previous migration/initialization.
Dirty bool
@ -74,19 +75,18 @@ type State struct {
// CurrentState provides the state of the boundary schema contained in the backing database.
func (b *Manager) CurrentState(ctx context.Context) (*State, error) {
const op = "schema.(Manager).CurrentState"
dbS := State{
BinarySchemaVersion: BinarySchemaVersion(b.dialect),
}
v, dirty, err := b.driver.CurrentState(ctx)
v, initialized, dirty, err := b.driver.CurrentState(ctx)
if err != nil {
return nil, err
return nil, errors.Wrap(err, op)
}
dbS.InitializationStarted = initialized
dbS.DatabaseSchemaVersion = v
dbS.Dirty = dirty
if v == nilVersion {
return &dbS, nil
}
dbS.InitializationStarted = true
return &dbS, nil
}
@ -147,7 +147,7 @@ func (b *Manager) RollForward(ctx context.Context) error {
b.driver.Unlock(ctx)
}()
curVersion, dirty, err := b.driver.CurrentState(ctx)
curVersion, _, dirty, err := b.driver.CurrentState(ctx)
if err != nil {
return errors.Wrap(err, op)
}
@ -156,7 +156,14 @@ func (b *Manager) RollForward(ctx context.Context) error {
return errors.New(errors.NotSpecificIntegrity, op, fmt.Sprintf("schema is dirty with version %d", curVersion))
}
return b.runMigrations(ctx, newStatementProvider(b.dialect, curVersion))
if err = b.runMigrations(ctx, newStatementProvider(b.dialect, curVersion)); err != nil {
return errors.Wrap(err, op)
}
return nil
}
type rollbacker interface {
Rollback() error
}
// runMigrations passes migration queries to a database driver and manages
@ -166,12 +173,21 @@ func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) erro
const op = "schema.(Manager).runMigrations"
if err := b.driver.StartRun(ctx); err != nil {
return err
return errors.Wrap(err, op)
}
if err := b.driver.EnsureVersionTable(ctx); err != nil {
return errors.Wrap(err, op)
}
for qp.Next() {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), op)
err := ctx.Err()
if d, ok := b.driver.(rollbacker); ok {
if rbErr := d.Rollback(); rbErr != nil {
err = multierror.Append(err, rbErr)
}
}
return errors.Wrap(err, op)
default:
// context is not done yet. Continue on to the next query to execute.
}
@ -180,7 +196,7 @@ func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) erro
}
}
if err := b.driver.CommitRun(); err != nil {
return err
return errors.Wrap(err, op)
}
return nil
}

@ -62,7 +62,8 @@ func TestCurrentState(t *testing.T) {
testDriver, err := postgres.New(ctx, d)
require.NoError(t, err)
require.NoError(t, testDriver.Run(ctx, strings.NewReader("SELECT 1"), 2))
require.NoError(t, testDriver.EnsureVersionTable(ctx))
require.NoError(t, testDriver.Run(ctx, strings.NewReader("select 1"), 2))
want = &State{
InitializationStarted: true,
@ -93,7 +94,7 @@ func TestRollForward(t *testing.T) {
_, err = postgres.New(ctx, d)
require.NoError(t, err)
// TODO: Extract out a way to mock the db to test failing rollforwards.
_, err = d.ExecContext(ctx, "TRUNCATE boundary_schema_version; INSERT INTO boundary_schema_version (Version, dirty) VALUES (2, true)")
_, err = d.ExecContext(ctx, "TRUNCATE boundary_schema_version; INSERT INTO boundary_schema_version (version, dirty) VALUES (2, true)")
assert.Error(t, m.RollForward(ctx))
}
@ -135,6 +136,26 @@ func TestRollForward_NotFromFresh(t *testing.T) {
assert.False(t, state.Dirty)
}
func TestRunMigration_canceledContext(t *testing.T) {
dialect := "postgres"
c, u, _, err := docker.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := sql.Open(dialect, u)
require.NoError(t, err)
ctx := context.Background()
m, err := NewManager(ctx, dialect, d)
require.NoError(t, err)
// TODO: Find a way to test different parts of the runMigrations loop.
ctx, cancel := context.WithCancel(ctx)
cancel()
assert.Error(t, m.runMigrations(ctx, newStatementProvider(dialect, 0)))
}
func TestRollForward_BadSQL(t *testing.T) {
dialect := "postgres"
oState := migrationStates[dialect]

@ -77,11 +77,6 @@ func New(ctx context.Context, instance *sql.DB) (*Postgres, error) {
conn: conn,
db: instance,
}
if err := px.ensureVersionTable(ctx); err != nil {
return nil, errors.Wrap(err, op)
}
return px, nil
}
@ -89,7 +84,7 @@ func New(ctx context.Context, instance *sql.DB) (*Postgres, error) {
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TrySharedLock(ctx context.Context) error {
const op = "postgres.(Postgres).TrySharedLock"
const query = "SELECT pg_try_advisory_lock_shared($1)"
const query = "select pg_try_advisory_lock_shared($1)"
r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(r.Err(), op)
@ -108,7 +103,7 @@ func (p *Postgres) TrySharedLock(ctx context.Context) error {
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TryLock(ctx context.Context) error {
const op = "postgres.(Postgres).TryLock"
const query = "SELECT pg_try_advisory_lock($1)"
const query = "select pg_try_advisory_lock($1)"
r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(r.Err(), op)
@ -127,7 +122,7 @@ func (p *Postgres) TryLock(ctx context.Context) error {
// if we were unable to get the lock before the context cancels.
func (p *Postgres) Lock(ctx context.Context) error {
const op = "postgres.(Postgres).Lock"
const query = "SELECT pg_advisory_lock($1)"
const query = "select pg_advisory_lock($1)"
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
@ -138,7 +133,7 @@ func (p *Postgres) Lock(ctx context.Context) error {
// release the lock before the context cancels.
func (p *Postgres) Unlock(ctx context.Context) error {
const op = "postgres.(Postgres).Unlock"
const query = `SELECT pg_advisory_unlock($1)`
const query = `select pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
@ -149,13 +144,29 @@ func (p *Postgres) Unlock(ctx context.Context) error {
// release the lock before the context cancels.
func (p *Postgres) UnlockShared(ctx context.Context) error {
const op = "postgres.(Postgres).UnlockShared"
query := `SELECT pg_advisory_unlock_shared($1)`
query := `select pg_advisory_unlock_shared($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// Rollback rolls back the outstanding transaction.
// Calling Rollback when there is not an outstanding transaction is an error.
func (p *Postgres) Rollback() error {
const op = "postgres.(Postgres).Rollback"
defer func() {
p.tx = nil
}()
if p.tx == nil {
return errors.New(errors.MigrationIntegrity, op, "no pending transaction")
}
if err := p.tx.Rollback(); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// StartRun starts a transaction that all subsequent calls to Run will use.
func (p *Postgres) StartRun(ctx context.Context) error {
tx, err := p.conn.BeginTx(ctx, nil)
@ -177,19 +188,21 @@ func (p *Postgres) CommitRun() error {
}
if err := p.tx.Commit(); err != nil {
if errRollback := p.tx.Rollback(); errRollback != nil {
return multierror.Append(err, errRollback)
err = multierror.Append(err, errRollback)
}
return errors.Wrap(err, op)
}
return nil
}
type execContexter interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// Executes the sql provided in the passed in io.Reader. The contents of the reader must
// Run executes the sql provided in the passed in io.Reader. The contents of the reader must
// fit in memory as the full content is read into a string before being passed to the
// backing database.
// backing database. EnsureVersionTable should be ran prior to this call.
func (p *Postgres) Run(ctx context.Context, migration io.Reader, version int) error {
const op = "postgres.(Postgres).Run"
migr, err := ioutil.ReadAll(migration)
@ -293,7 +306,7 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro
return tx.Rollback()
}
query := `TRUNCATE ` + pq.QuoteIdentifier(defaultMigrationsTable)
query := `truncate ` + pq.QuoteIdentifier(defaultMigrationsTable)
if _, err := tx.ExecContext(ctx, query); err != nil {
if errRollback := rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
@ -305,8 +318,8 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro
// empty schema Version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == nilVersion && dirty) {
query = `INSERT INTO ` + pq.QuoteIdentifier(defaultMigrationsTable) +
` (Version, dirty) VALUES ($1, $2)`
query = `insert into ` + pq.QuoteIdentifier(defaultMigrationsTable) +
` (version, dirty) values ($1, $2)`
if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil {
if errRollback := rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
@ -323,33 +336,58 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro
return nil
}
// CurrentState returns the version, if the database is currently in a dirty state, and any error.
// A version value of -1 indicates no version is set.
func (p *Postgres) CurrentState(ctx context.Context) (version int, dirty bool, err error) {
// CurrentState returns the version, if the database was ever initialized
// previously, if it is currently in a dirty state, and any error. A version
// value of -1 indicates no version is set.
func (p *Postgres) CurrentState(ctx context.Context) (version int, previouslyRan, dirty bool, err error) {
const op = "postgres.(Postgres).CurrentState"
query := `SELECT Version, dirty FROM ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` LIMIT 1`
err = p.conn.QueryRowContext(ctx, query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return nilVersion, false, nil
case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return nilVersion, false, nil
}
}
return 0, false, errors.Wrap(err, op)
default:
return version, dirty, nil
version = nilVersion
previouslyRan, dirty = false, false
tableQuery := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_name in ('schema_migrations', '` + defaultMigrationsTable + `')`
tableResult, err := p.conn.QueryContext(ctx, tableQuery)
if err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(err, op)
}
defer tableResult.Close()
if !tableResult.Next() {
// No version table found
return nilVersion, previouslyRan, dirty, nil
}
tableName := defaultMigrationsTable
if err := tableResult.Scan(&tableName); err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(err, op)
}
previouslyRan = true
if tableResult.Next() {
return nilVersion, previouslyRan, dirty, errors.New(errors.MigrationIntegrity, op, "both old and new migration tables exist")
}
query := `select version, dirty from ` + pq.QuoteIdentifier(tableName)
results, err := p.conn.QueryContext(ctx, query)
if err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(err, op)
}
defer results.Close()
if !results.Next() {
// no version recorded
return nilVersion, previouslyRan, dirty, nil
}
if err := results.Scan(&version, &dirty); err != nil {
return nilVersion, previouslyRan, dirty, errors.Wrap(err, op)
}
if results.Next() {
return nilVersion, previouslyRan, dirty, errors.New(errors.MigrationIntegrity, op, "to many versions in version table")
}
return version, previouslyRan, dirty, nil
}
func (p *Postgres) drop(ctx context.Context) (err error) {
const op = "postgres.(Postgres).drop"
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
query := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(ctx, query)
if err != nil {
return errors.Wrap(err, op)
@ -379,7 +417,7 @@ func (p *Postgres) drop(ctx context.Context) (err error) {
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
query = `drop table if exists ` + pq.QuoteIdentifier(t) + ` cascade`
if _, err := p.conn.ExecContext(ctx, query); err != nil {
return errors.Wrap(err, op)
}
@ -389,23 +427,45 @@ func (p *Postgres) drop(ctx context.Context) (err error) {
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the postgres type.
func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) {
const op = "postgres.(Postgres).ensureVersionTable"
if err = p.TryLock(ctx); err != nil {
// EnsureVersionTable checks if versions table exists and, if not, creates it.
func (p *Postgres) EnsureVersionTable(ctx context.Context) (err error) {
const op = "postgres.(Postgres).EnsureVersionTable"
var extr execContexter = p.conn
rollback := func() error { return nil }
if p.tx != nil {
extr = p.tx
rollback = func() error {
defer func() { p.tx = nil }()
return p.tx.Rollback()
}
}
query := `select exists (select 1 from information_schema.tables where table_schema=(select current_schema()) and table_name = '` + defaultMigrationsTable + `');`
exists := false
if err := extr.QueryRowContext(ctx, query).Scan(&exists); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
return errors.Wrap(err, op)
}
if exists {
return nil
}
defer func() {
if e := p.Unlock(ctx); e != nil {
err = multierror.Append(err, e)
updateQuery := `alter table if exists schema_migrations rename to ` + defaultMigrationsTable + `;`
if _, err = extr.ExecContext(ctx, updateQuery); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
}()
return errors.Wrap(err, op)
}
query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (Version bigint primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(ctx, query); err != nil {
createStmt := `create table if not exists ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (version bigint primary key, dirty boolean not null)`
if _, err = extr.ExecContext(ctx, createStmt); err != nil {
if wpErr := rollback(); wpErr != nil {
err = multierror.Append(err, wpErr)
}
return errors.Wrap(err, op)
}

@ -117,7 +117,7 @@ func TestDbStuff(t *testing.T) {
})
}
func TestVersion_NoVersionTable(t *testing.T) {
func TestCurrentState_NoVersionTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
@ -138,8 +138,43 @@ func TestVersion_NoVersionTable(t *testing.T) {
// Drop the version table so calls to CurrentState don't rely on that
d.drop(ctx)
v, dirt, err := d.CurrentState(ctx)
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
})
}
func TestCurrentState_ToManyTables(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
// Create the most recent table
d.EnsureVersionTable(ctx)
// Create the legacy version of the table.
oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = d.conn.ExecContext(ctx, oldTableCreate)
require.NoError(t, err)
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.Error(t, err)
assert.True(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
})
@ -163,6 +198,10 @@ func TestMultiStatement(t *testing.T) {
t.Error(err)
}
}()
if err := d.EnsureVersionTable(ctx); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
@ -197,12 +236,33 @@ func TestTransaction(t *testing.T) {
}
}()
v, alreadyRan, dirty, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
// Fail the initial setup of the db.
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3))
assert.Error(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2))
assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1"), 3))
assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3))
assert.NoError(t, d.CommitRun())
v, dirty, err := d.CurrentState(ctx)
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
@ -210,8 +270,10 @@ func TestTransaction(t *testing.T) {
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30))
assert.Error(t, d.CommitRun())
v, dirty, err = d.CurrentState(ctx)
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
})
@ -221,70 +283,91 @@ func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer func() {
if err := d.close(t); err != nil {
t.Fatal(err)
}
}()
require.NoError(t, d.EnsureVersionTable(ctx))
// create foobar schema
if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1); err != nil {
t.Fatal(err)
}
require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1))
// re-connect using that schema
d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer func() {
if err := d2.close(t); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.CurrentState(ctx)
version, alreadyRan, _, err := d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, nilVersion, version)
assert.False(t, alreadyRan)
// now update CurrentState and compare
require.NoError(t, d2.EnsureVersionTable(ctx))
require.NoError(t, d2.setVersion(ctx, 2, false))
version, alreadyRan, _, err = d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 2, version)
assert.True(t, alreadyRan)
// meanwhile, the public schema still has the other CurrentState
version, alreadyRan, _, err = d.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 1, version)
assert.True(t, alreadyRan)
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
if version != nilVersion {
t.Fatal("expected NilVersion")
}
// now update CurrentState and compare
if err := d2.setVersion(ctx, 2, false); err != nil {
addr := pgConnectionString(ip, port)
ps, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
version, _, err = d2.CurrentState(ctx)
test(t, ps, []byte("SELECT 1"))
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected Version 2")
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
// meanwhile, the public schema still has the other CurrentState
version, _, err = d.CurrentState(ctx)
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected Version 2")
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
})
}
func TestPostgres_Lock(t *testing.T) {
func TestEnsureTable_Fresh(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
@ -293,32 +376,118 @@ func TestPostgres_Lock(t *testing.T) {
}
addr := pgConnectionString(ip, port)
ps, err := open(t, ctx, addr)
p, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
test(t, ps, []byte("SELECT 1"))
tableCreated := false
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.False(t, tableCreated)
err = ps.Lock(ctx)
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.True(t, tableCreated)
})
}
func TestEnsureTable_ExistingTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock(ctx)
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
assert.NoError(t, p.EnsureVersionTable(ctx))
err = ps.Lock(ctx)
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = p.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
assert.NoError(t, p.EnsureVersionTable(ctx))
})
}
func TestEnsureTable_OldTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock(ctx)
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = p.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
tableExists := false
oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')"
assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.True(t, tableExists)
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.False(t, tableExists)
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.False(t, tableExists)
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.True(t, tableExists)
})
}
func TestRollback(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
assert.NoError(t, p.StartRun(ctx))
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, p.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2))
var exists bool
query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))"
assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.True(t, exists)
assert.NoError(t, p.Rollback())
assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.False(t, exists)
})
}

@ -32,11 +32,11 @@ import (
"bytes"
"context"
"database/sql"
"io"
"testing"
"time"
"github.com/golang-migrate/migrate/v4/database"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -45,24 +45,21 @@ func test(t *testing.T, d *Postgres, migration []byte) {
if migration == nil {
t.Fatal("test must provide migration reader")
}
ctx := context.Background()
v, alreadyRan, dirty, err := d.CurrentState(ctx)
require.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, nilVersion, v)
testNilVersion(t, d) // test first
testLockAndUnlock(t, d)
testRun(t, d, bytes.NewReader(migration))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.Run(ctx, bytes.NewReader(migration), 1))
testSetVersion(t, d) // also tests CurrentState()
// drop breaks the driver, so test it last.
testDrop(t, d)
}
func testNilVersion(t *testing.T, d *Postgres) {
ctx := context.Background()
v, _, err := d.CurrentState(ctx)
if err != nil {
t.Fatal(err)
}
if v != database.NilVersion {
t.Fatalf("Version: expected Version to be NilVersion (-1), got %v", v)
}
assert.NoError(t, d.drop(ctx))
}
func testLockAndUnlock(t *testing.T, d *Postgres) {
@ -71,47 +68,20 @@ func testLockAndUnlock(t *testing.T, d *Postgres) {
ctx, _ = context.WithTimeout(ctx, 15*time.Second)
// locking twice is ok, no error
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
require.NoError(t, d.Lock(ctx))
assert.NoError(t, d.Lock(ctx))
// Unlock
if err := d.Unlock(ctx); err != nil {
t.Fatalf("error unlocking: %v", err)
}
assert.NoError(t, d.Unlock(ctx))
// try to Lock
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
if err := d.Unlock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
}
func testRun(t *testing.T, d *Postgres, migration io.Reader) {
ctx := context.Background()
if migration == nil {
t.Fatal("migration can't be nil")
}
if err := d.Run(ctx, migration, 1); err != nil {
t.Fatal(err)
}
}
func testDrop(t *testing.T, d *Postgres) {
ctx := context.Background()
if err := d.drop(ctx); err != nil {
t.Fatal(err)
}
assert.NoError(t, d.Lock(ctx))
assert.NoError(t, d.Unlock(ctx))
}
func testSetVersion(t *testing.T, d *Postgres) {
ctx := context.Background()
require.NoError(t, d.EnsureVersionTable(ctx))
// nolint:maligned
testCases := []struct {
name string
@ -132,11 +102,12 @@ func testSetVersion(t *testing.T, d *Postgres) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// d.EnsureVersionTable(ctx)
err := d.setVersion(ctx, tc.version, tc.dirty)
if err != tc.expectedErr {
t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr)
}
v, dirty, readErr := d.CurrentState(ctx)
v, _, dirty, readErr := d.CurrentState(ctx)
if readErr != tc.expectedReadErr {
t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr)
}

Loading…
Cancel
Save