diff --git a/internal/cmd/commands/database/funcs.go b/internal/cmd/commands/database/funcs.go index 96b152ffe5..c3661cb92a 100644 --- a/internal/cmd/commands/database/funcs.go +++ b/internal/cmd/commands/database/funcs.go @@ -15,7 +15,7 @@ import ( // It owns the reporting to the UI any errors. // Returns a cleanup function which must be called even if an error is returned and // an error code where a non-zero value indicates an error happened. -func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string) (func(), int) { +func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string, requireFresh bool) (func(), int) { noop := func() {} // This database is used to keep an exclusive lock on the database for the // remainder of the command @@ -52,6 +52,10 @@ func migrateDatabase(ctx context.Context, ui cli.Ui, dialect, u string) (func(), ui.Error(fmt.Errorf("Error getting database state: %w", err).Error()) return unlock, 1 } + if requireFresh && st.InitializationStarted { + ui.Error(base.WrapAtLength("Database has already been initialized. Please use 'boundary database migrate'.")) + return unlock, 1 + } if st.Dirty { ui.Error(base.WrapAtLength("Database is in a bad state. Please revert back to the last known good state.")) return unlock, 1 diff --git a/internal/cmd/commands/database/funcs_test.go b/internal/cmd/commands/database/funcs_test.go index 2a67139deb..0cc4d524a3 100644 --- a/internal/cmd/commands/database/funcs_test.go +++ b/internal/cmd/commands/database/funcs_test.go @@ -19,6 +19,7 @@ func TestMigrateDatabase(t *testing.T) { cases := []struct { name string + requireFresh bool urlProvider func() string expectedCode int expectedOutput string @@ -37,6 +38,25 @@ func TestMigrateDatabase(t *testing.T) { expectedCode: 0, expectedOutput: "Migrations successfully run.\n", }, + { + name: "old_version_table_used", + urlProvider: func() string { + c, u, _, err := db.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + dBase, err := sql.Open(dialect, u) + require.NoError(t, err) + + createStmt := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = dBase.Exec(createStmt) + require.NoError(t, err) + return u + }, + expectedCode: 0, + expectedOutput: "Migrations successfully run.\n", + }, { name: "bad_url", urlProvider: func() string { return "badurl" }, @@ -65,13 +85,77 @@ func TestMigrateDatabase(t *testing.T) { expectedCode: 1, expectedError: "Unable to capture a lock on the database.\n", }, + { + name: "basic_require_fresh", + requireFresh: true, + urlProvider: func() string { + c, u, _, err := db.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + return u + }, + expectedCode: 0, + expectedOutput: "Migrations successfully run.\n", + }, + { + name: "old_version_table_used_require_fresh", + requireFresh: true, + urlProvider: func() string { + c, u, _, err := db.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + dBase, err := sql.Open(dialect, u) + require.NoError(t, err) + + createStmt := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = dBase.Exec(createStmt) + require.NoError(t, err) + return u + }, + expectedCode: 1, + expectedError: "Database has already been initialized. Please use 'boundary database\nmigrate'.\n", + }, + { + name: "bad_url_require_fresh", + requireFresh: true, + urlProvider: func() string { return "badurl" }, + expectedCode: 1, + expectedError: "Unable to connect to the database at \"badurl\"\n", + }, + { + name: "cant_get_lock_require_fresh", + requireFresh: true, + urlProvider: func() string { + c, u, _, err := db.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + dBase, err := sql.Open(dialect, u) + require.NoError(t, err) + + man, err := schema.NewManager(ctx, dialect, dBase) + require.NoError(t, err) + // This is an advisory lock on the DB which is released when the DB session ends. + err = man.ExclusiveLock(ctx) + require.NoError(t, err) + + return u + }, + expectedCode: 1, + expectedError: "Unable to capture a lock on the database.\n", + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { u := tc.urlProvider() ui := cli.NewMockUi() - clean, errCode := migrateDatabase(ctx, ui, dialect, u) + clean, errCode := migrateDatabase(ctx, ui, dialect, u, tc.requireFresh) clean() assert.EqualValues(t, tc.expectedCode, errCode) assert.Equal(t, tc.expectedOutput, ui.OutputWriter.String()) diff --git a/internal/cmd/commands/database/init.go b/internal/cmd/commands/database/init.go index 50be79ea39..6a55d148cb 100644 --- a/internal/cmd/commands/database/init.go +++ b/internal/cmd/commands/database/init.go @@ -260,7 +260,7 @@ func (c *InitCommand) Run(args []string) (retCode int) { return 1 } - clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl) + clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, true) defer clean() if errCode != 0 { return errCode diff --git a/internal/cmd/commands/database/migrate.go b/internal/cmd/commands/database/migrate.go index 9e20927cdd..a9ba0cc959 100644 --- a/internal/cmd/commands/database/migrate.go +++ b/internal/cmd/commands/database/migrate.go @@ -200,7 +200,7 @@ func (c *MigrateCommand) Run(args []string) (retCode int) { return 1 } - clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl) + clean, errCode := migrateDatabase(c.Context, c.UI, dialect, migrationUrl, false) defer clean() if errCode != 0 { return errCode diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index b8019fc618..ce4d3a1e60 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -367,6 +367,9 @@ func (c *Command) Run(args []string) int { c.UI.Error(fmt.Errorf("Error checking schema state: %w", err).Error()) return 1 } + if !ckState.InitializationStarted { + c.UI.Error(base.WrapAtLength("The database has not been initialized. Please run 'boundary database init'.")) + } if ckState.Dirty { c.UI.Error(base.WrapAtLength("Database is in a bad state. Please revert the database into the last known good state.")) return 1 diff --git a/internal/db/schema/manager.go b/internal/db/schema/manager.go index 12f81d71c0..71dd2aa868 100644 --- a/internal/db/schema/manager.go +++ b/internal/db/schema/manager.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/db/schema/postgres" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/go-multierror" ) // driver provides functionality to a database. @@ -29,7 +30,8 @@ type driver interface { // executing Run. Run(context.Context, io.Reader, int) error // A version of -1 indicates no version is set. - CurrentState(context.Context) (int, bool, error) + CurrentState(context.Context) (ver int, everRan bool, dirty bool, err error) + EnsureVersionTable(ctx context.Context) error } // Manager provides a way to run operations and retrieve information regarding @@ -61,8 +63,7 @@ func NewManager(ctx context.Context, dialect string, db *sql.DB) (*Manager, erro // State contains information regarding the current state of a boundary database's schema. type State struct { - // InitializationStarted indicates if the current database has already been initialized - // (successfully or not) at least once. + // InitializationStarted indicates if the current database has been initialized previously. InitializationStarted bool // Dirty is set to true if the database failed in a previous migration/initialization. Dirty bool @@ -74,19 +75,18 @@ type State struct { // CurrentState provides the state of the boundary schema contained in the backing database. func (b *Manager) CurrentState(ctx context.Context) (*State, error) { + const op = "schema.(Manager).CurrentState" dbS := State{ BinarySchemaVersion: BinarySchemaVersion(b.dialect), } - v, dirty, err := b.driver.CurrentState(ctx) + + v, initialized, dirty, err := b.driver.CurrentState(ctx) if err != nil { - return nil, err + return nil, errors.Wrap(err, op) } + dbS.InitializationStarted = initialized dbS.DatabaseSchemaVersion = v dbS.Dirty = dirty - if v == nilVersion { - return &dbS, nil - } - dbS.InitializationStarted = true return &dbS, nil } @@ -147,7 +147,7 @@ func (b *Manager) RollForward(ctx context.Context) error { b.driver.Unlock(ctx) }() - curVersion, dirty, err := b.driver.CurrentState(ctx) + curVersion, _, dirty, err := b.driver.CurrentState(ctx) if err != nil { return errors.Wrap(err, op) } @@ -156,7 +156,14 @@ func (b *Manager) RollForward(ctx context.Context) error { return errors.New(errors.NotSpecificIntegrity, op, fmt.Sprintf("schema is dirty with version %d", curVersion)) } - return b.runMigrations(ctx, newStatementProvider(b.dialect, curVersion)) + if err = b.runMigrations(ctx, newStatementProvider(b.dialect, curVersion)); err != nil { + return errors.Wrap(err, op) + } + return nil +} + +type rollbacker interface { + Rollback() error } // runMigrations passes migration queries to a database driver and manages @@ -166,12 +173,21 @@ func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) erro const op = "schema.(Manager).runMigrations" if err := b.driver.StartRun(ctx); err != nil { - return err + return errors.Wrap(err, op) + } + if err := b.driver.EnsureVersionTable(ctx); err != nil { + return errors.Wrap(err, op) } for qp.Next() { select { case <-ctx.Done(): - return errors.Wrap(ctx.Err(), op) + err := ctx.Err() + if d, ok := b.driver.(rollbacker); ok { + if rbErr := d.Rollback(); rbErr != nil { + err = multierror.Append(err, rbErr) + } + } + return errors.Wrap(err, op) default: // context is not done yet. Continue on to the next query to execute. } @@ -180,7 +196,7 @@ func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) erro } } if err := b.driver.CommitRun(); err != nil { - return err + return errors.Wrap(err, op) } return nil } diff --git a/internal/db/schema/manager_test.go b/internal/db/schema/manager_test.go index 760d0c4c16..08f8124cad 100644 --- a/internal/db/schema/manager_test.go +++ b/internal/db/schema/manager_test.go @@ -62,7 +62,8 @@ func TestCurrentState(t *testing.T) { testDriver, err := postgres.New(ctx, d) require.NoError(t, err) - require.NoError(t, testDriver.Run(ctx, strings.NewReader("SELECT 1"), 2)) + require.NoError(t, testDriver.EnsureVersionTable(ctx)) + require.NoError(t, testDriver.Run(ctx, strings.NewReader("select 1"), 2)) want = &State{ InitializationStarted: true, @@ -93,7 +94,7 @@ func TestRollForward(t *testing.T) { _, err = postgres.New(ctx, d) require.NoError(t, err) // TODO: Extract out a way to mock the db to test failing rollforwards. - _, err = d.ExecContext(ctx, "TRUNCATE boundary_schema_version; INSERT INTO boundary_schema_version (Version, dirty) VALUES (2, true)") + _, err = d.ExecContext(ctx, "TRUNCATE boundary_schema_version; INSERT INTO boundary_schema_version (version, dirty) VALUES (2, true)") assert.Error(t, m.RollForward(ctx)) } @@ -135,6 +136,26 @@ func TestRollForward_NotFromFresh(t *testing.T) { assert.False(t, state.Dirty) } +func TestRunMigration_canceledContext(t *testing.T) { + dialect := "postgres" + c, u, _, err := docker.StartDbInDocker(dialect) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := sql.Open(dialect, u) + require.NoError(t, err) + + ctx := context.Background() + m, err := NewManager(ctx, dialect, d) + require.NoError(t, err) + + // TODO: Find a way to test different parts of the runMigrations loop. + ctx, cancel := context.WithCancel(ctx) + cancel() + assert.Error(t, m.runMigrations(ctx, newStatementProvider(dialect, 0))) +} + func TestRollForward_BadSQL(t *testing.T) { dialect := "postgres" oState := migrationStates[dialect] diff --git a/internal/db/schema/postgres/postgres.go b/internal/db/schema/postgres/postgres.go index feccbc3e4d..c3020a1f1e 100644 --- a/internal/db/schema/postgres/postgres.go +++ b/internal/db/schema/postgres/postgres.go @@ -77,11 +77,6 @@ func New(ctx context.Context, instance *sql.DB) (*Postgres, error) { conn: conn, db: instance, } - - if err := px.ensureVersionTable(ctx); err != nil { - return nil, errors.Wrap(err, op) - } - return px, nil } @@ -89,7 +84,7 @@ func New(ctx context.Context, instance *sql.DB) (*Postgres, error) { // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) TrySharedLock(ctx context.Context) error { const op = "postgres.(Postgres).TrySharedLock" - const query = "SELECT pg_try_advisory_lock_shared($1)" + const query = "select pg_try_advisory_lock_shared($1)" r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId) if r.Err() != nil { return errors.Wrap(r.Err(), op) @@ -108,7 +103,7 @@ func (p *Postgres) TrySharedLock(ctx context.Context) error { // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) TryLock(ctx context.Context) error { const op = "postgres.(Postgres).TryLock" - const query = "SELECT pg_try_advisory_lock($1)" + const query = "select pg_try_advisory_lock($1)" r := p.conn.QueryRowContext(ctx, query, schemaAccessLockId) if r.Err() != nil { return errors.Wrap(r.Err(), op) @@ -127,7 +122,7 @@ func (p *Postgres) TryLock(ctx context.Context) error { // if we were unable to get the lock before the context cancels. func (p *Postgres) Lock(ctx context.Context) error { const op = "postgres.(Postgres).Lock" - const query = "SELECT pg_advisory_lock($1)" + const query = "select pg_advisory_lock($1)" if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { return errors.Wrap(err, op) } @@ -138,7 +133,7 @@ func (p *Postgres) Lock(ctx context.Context) error { // release the lock before the context cancels. func (p *Postgres) Unlock(ctx context.Context) error { const op = "postgres.(Postgres).Unlock" - const query = `SELECT pg_advisory_unlock($1)` + const query = `select pg_advisory_unlock($1)` if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { return errors.Wrap(err, op) } @@ -149,13 +144,29 @@ func (p *Postgres) Unlock(ctx context.Context) error { // release the lock before the context cancels. func (p *Postgres) UnlockShared(ctx context.Context) error { const op = "postgres.(Postgres).UnlockShared" - query := `SELECT pg_advisory_unlock_shared($1)` + query := `select pg_advisory_unlock_shared($1)` if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil { return errors.Wrap(err, op) } return nil } +// Rollback rolls back the outstanding transaction. +// Calling Rollback when there is not an outstanding transaction is an error. +func (p *Postgres) Rollback() error { + const op = "postgres.(Postgres).Rollback" + defer func() { + p.tx = nil + }() + if p.tx == nil { + return errors.New(errors.MigrationIntegrity, op, "no pending transaction") + } + if err := p.tx.Rollback(); err != nil { + return errors.Wrap(err, op) + } + return nil +} + // StartRun starts a transaction that all subsequent calls to Run will use. func (p *Postgres) StartRun(ctx context.Context) error { tx, err := p.conn.BeginTx(ctx, nil) @@ -177,19 +188,21 @@ func (p *Postgres) CommitRun() error { } if err := p.tx.Commit(); err != nil { if errRollback := p.tx.Rollback(); errRollback != nil { - return multierror.Append(err, errRollback) + err = multierror.Append(err, errRollback) } + return errors.Wrap(err, op) } return nil } type execContexter interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } -// Executes the sql provided in the passed in io.Reader. The contents of the reader must +// Run executes the sql provided in the passed in io.Reader. The contents of the reader must // fit in memory as the full content is read into a string before being passed to the -// backing database. +// backing database. EnsureVersionTable should be ran prior to this call. func (p *Postgres) Run(ctx context.Context, migration io.Reader, version int) error { const op = "postgres.(Postgres).Run" migr, err := ioutil.ReadAll(migration) @@ -293,7 +306,7 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro return tx.Rollback() } - query := `TRUNCATE ` + pq.QuoteIdentifier(defaultMigrationsTable) + query := `truncate ` + pq.QuoteIdentifier(defaultMigrationsTable) if _, err := tx.ExecContext(ctx, query); err != nil { if errRollback := rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -305,8 +318,8 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro // empty schema Version for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == nilVersion && dirty) { - query = `INSERT INTO ` + pq.QuoteIdentifier(defaultMigrationsTable) + - ` (Version, dirty) VALUES ($1, $2)` + query = `insert into ` + pq.QuoteIdentifier(defaultMigrationsTable) + + ` (version, dirty) values ($1, $2)` if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil { if errRollback := rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -323,33 +336,58 @@ func (p *Postgres) setVersion(ctx context.Context, version int, dirty bool) erro return nil } -// CurrentState returns the version, if the database is currently in a dirty state, and any error. -// A version value of -1 indicates no version is set. -func (p *Postgres) CurrentState(ctx context.Context) (version int, dirty bool, err error) { +// CurrentState returns the version, if the database was ever initialized +// previously, if it is currently in a dirty state, and any error. A version +// value of -1 indicates no version is set. +func (p *Postgres) CurrentState(ctx context.Context) (version int, previouslyRan, dirty bool, err error) { const op = "postgres.(Postgres).CurrentState" - query := `SELECT Version, dirty FROM ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` LIMIT 1` - err = p.conn.QueryRowContext(ctx, query).Scan(&version, &dirty) - switch { - case err == sql.ErrNoRows: - return nilVersion, false, nil - - case err != nil: - if e, ok := err.(*pq.Error); ok { - if e.Code.Name() == "undefined_table" { - return nilVersion, false, nil - } - } - return 0, false, errors.Wrap(err, op) - default: - return version, dirty, nil + version = nilVersion + previouslyRan, dirty = false, false + + tableQuery := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_name in ('schema_migrations', '` + defaultMigrationsTable + `')` + tableResult, err := p.conn.QueryContext(ctx, tableQuery) + if err != nil { + return nilVersion, previouslyRan, dirty, errors.Wrap(err, op) + } + defer tableResult.Close() + if !tableResult.Next() { + // No version table found + return nilVersion, previouslyRan, dirty, nil + } + + tableName := defaultMigrationsTable + if err := tableResult.Scan(&tableName); err != nil { + return nilVersion, previouslyRan, dirty, errors.Wrap(err, op) } + previouslyRan = true + if tableResult.Next() { + return nilVersion, previouslyRan, dirty, errors.New(errors.MigrationIntegrity, op, "both old and new migration tables exist") + } + + query := `select version, dirty from ` + pq.QuoteIdentifier(tableName) + results, err := p.conn.QueryContext(ctx, query) + if err != nil { + return nilVersion, previouslyRan, dirty, errors.Wrap(err, op) + } + defer results.Close() + if !results.Next() { + // no version recorded + return nilVersion, previouslyRan, dirty, nil + } + if err := results.Scan(&version, &dirty); err != nil { + return nilVersion, previouslyRan, dirty, errors.Wrap(err, op) + } + if results.Next() { + return nilVersion, previouslyRan, dirty, errors.New(errors.MigrationIntegrity, op, "to many versions in version table") + } + return version, previouslyRan, dirty, nil } func (p *Postgres) drop(ctx context.Context) (err error) { const op = "postgres.(Postgres).drop" // select all tables in current schema - query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'` + query := `select table_name from information_schema.tables where table_schema=(select current_schema()) and table_type='BASE TABLE'` tables, err := p.conn.QueryContext(ctx, query) if err != nil { return errors.Wrap(err, op) @@ -379,7 +417,7 @@ func (p *Postgres) drop(ctx context.Context) (err error) { if len(tableNames) > 0 { // delete one by one ... for _, t := range tableNames { - query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` + query = `drop table if exists ` + pq.QuoteIdentifier(t) + ` cascade` if _, err := p.conn.ExecContext(ctx, query); err != nil { return errors.Wrap(err, op) } @@ -389,23 +427,45 @@ func (p *Postgres) drop(ctx context.Context) (err error) { return nil } -// ensureVersionTable checks if versions table exists and, if not, creates it. -// Note that this function locks the database, which deviates from the usual -// convention of "caller locks" in the postgres type. -func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) { - const op = "postgres.(Postgres).ensureVersionTable" - if err = p.TryLock(ctx); err != nil { +// EnsureVersionTable checks if versions table exists and, if not, creates it. +func (p *Postgres) EnsureVersionTable(ctx context.Context) (err error) { + const op = "postgres.(Postgres).EnsureVersionTable" + + var extr execContexter = p.conn + rollback := func() error { return nil } + if p.tx != nil { + extr = p.tx + rollback = func() error { + defer func() { p.tx = nil }() + return p.tx.Rollback() + } + } + + query := `select exists (select 1 from information_schema.tables where table_schema=(select current_schema()) and table_name = '` + defaultMigrationsTable + `');` + exists := false + if err := extr.QueryRowContext(ctx, query).Scan(&exists); err != nil { + if wpErr := rollback(); wpErr != nil { + err = multierror.Append(err, wpErr) + } return errors.Wrap(err, op) } + if exists { + return nil + } - defer func() { - if e := p.Unlock(ctx); e != nil { - err = multierror.Append(err, e) + updateQuery := `alter table if exists schema_migrations rename to ` + defaultMigrationsTable + `;` + if _, err = extr.ExecContext(ctx, updateQuery); err != nil { + if wpErr := rollback(); wpErr != nil { + err = multierror.Append(err, wpErr) } - }() + return errors.Wrap(err, op) + } - query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (Version bigint primary key, dirty boolean not null)` - if _, err = p.conn.ExecContext(ctx, query); err != nil { + createStmt := `create table if not exists ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (version bigint primary key, dirty boolean not null)` + if _, err = extr.ExecContext(ctx, createStmt); err != nil { + if wpErr := rollback(); wpErr != nil { + err = multierror.Append(err, wpErr) + } return errors.Wrap(err, op) } diff --git a/internal/db/schema/postgres/postgres_test.go b/internal/db/schema/postgres/postgres_test.go index b4be6e3433..70fa6023ae 100644 --- a/internal/db/schema/postgres/postgres_test.go +++ b/internal/db/schema/postgres/postgres_test.go @@ -117,7 +117,7 @@ func TestDbStuff(t *testing.T) { }) } -func TestVersion_NoVersionTable(t *testing.T) { +func TestCurrentState_NoVersionTable(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ctx := context.Background() ip, port, err := c.FirstPort() @@ -138,8 +138,43 @@ func TestVersion_NoVersionTable(t *testing.T) { // Drop the version table so calls to CurrentState don't rely on that d.drop(ctx) - v, dirt, err := d.CurrentState(ctx) + v, alreadyRan, dirt, err := d.CurrentState(ctx) assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.Equal(t, v, nilVersion) + assert.False(t, dirt) + }) +} + +func TestCurrentState_ToManyTables(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + d, err := open(t, ctx, addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.close(t); err != nil { + t.Error(err) + } + }() + + // Create the most recent table + d.EnsureVersionTable(ctx) + + // Create the legacy version of the table. + oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = d.conn.ExecContext(ctx, oldTableCreate) + require.NoError(t, err) + v, alreadyRan, dirt, err := d.CurrentState(ctx) + assert.Error(t, err) + assert.True(t, alreadyRan) assert.Equal(t, v, nilVersion) assert.False(t, dirt) }) @@ -163,6 +198,10 @@ func TestMultiStatement(t *testing.T) { t.Error(err) } }() + if err := d.EnsureVersionTable(ctx); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil { t.Fatalf("expected err to be nil, got %v", err) } @@ -197,12 +236,33 @@ func TestTransaction(t *testing.T) { } }() + v, alreadyRan, dirty, err := d.CurrentState(ctx) + assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, -1, v) + + // Fail the initial setup of the db. + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3)) + assert.Error(t, d.CommitRun()) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) + assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, -1, v) + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2)) - assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1"), 3)) + assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3)) assert.NoError(t, d.CommitRun()) - v, dirty, err := d.CurrentState(ctx) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) assert.NoError(t, err) + assert.True(t, alreadyRan) assert.False(t, dirty) assert.Equal(t, 3, v) @@ -210,8 +270,10 @@ func TestTransaction(t *testing.T) { assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20)) assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30)) assert.Error(t, d.CommitRun()) - v, dirty, err = d.CurrentState(ctx) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) assert.NoError(t, err) + assert.True(t, alreadyRan) assert.False(t, dirty) assert.Equal(t, 3, v) }) @@ -221,70 +283,91 @@ func TestWithSchema(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ctx := context.Background() ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) addr := pgConnectionString(ip, port) d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer func() { if err := d.close(t); err != nil { t.Fatal(err) } }() + require.NoError(t, d.EnsureVersionTable(ctx)) // create foobar schema - if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1); err != nil { - t.Fatal(err) - } + require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1)) // re-connect using that schema d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar", pgPassword, ip, port)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer func() { if err := d2.close(t); err != nil { t.Fatal(err) } }() - version, _, err := d2.CurrentState(ctx) + version, alreadyRan, _, err := d2.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, nilVersion, version) + assert.False(t, alreadyRan) + + // now update CurrentState and compare + require.NoError(t, d2.EnsureVersionTable(ctx)) + require.NoError(t, d2.setVersion(ctx, 2, false)) + version, alreadyRan, _, err = d2.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, 2, version) + assert.True(t, alreadyRan) + + // meanwhile, the public schema still has the other CurrentState + version, alreadyRan, _, err = d.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, 1, version) + assert.True(t, alreadyRan) + }) +} + +func TestPostgres_Lock(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) } - if version != nilVersion { - t.Fatal("expected NilVersion") - } - // now update CurrentState and compare - if err := d2.setVersion(ctx, 2, false); err != nil { + addr := pgConnectionString(ip, port) + ps, err := open(t, ctx, addr) + if err != nil { t.Fatal(err) } - version, _, err = d2.CurrentState(ctx) + + test(t, ps, []byte("SELECT 1")) + + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - if version != 2 { - t.Fatal("expected Version 2") + + err = ps.Unlock(ctx) + if err != nil { + t.Fatal(err) } - // meanwhile, the public schema still has the other CurrentState - version, _, err = d.CurrentState(ctx) + err = ps.Lock(ctx) if err != nil { t.Fatal(err) } - if version != 1 { - t.Fatal("expected Version 2") + + err = ps.Unlock(ctx) + if err != nil { + t.Fatal(err) } }) } -func TestPostgres_Lock(t *testing.T) { +func TestEnsureTable_Fresh(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ctx := context.Background() ip, port, err := c.FirstPort() @@ -293,32 +376,118 @@ func TestPostgres_Lock(t *testing.T) { } addr := pgConnectionString(ip, port) - ps, err := open(t, ctx, addr) + p, err := open(t, ctx, addr) if err != nil { - t.Fatal(err) + require.NoError(t, err) } + t.Cleanup(func() { + require.NoError(t, p.close(t)) + }) - test(t, ps, []byte("SELECT 1")) + tableCreated := false + query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" + assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated)) + assert.False(t, tableCreated) - err = ps.Lock(ctx) + assert.NoError(t, p.EnsureVersionTable(ctx)) + assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated)) + assert.True(t, tableCreated) + }) +} + +func TestEnsureTable_ExistingTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) } - err = ps.Unlock(ctx) + addr := pgConnectionString(ip, port) + p, err := open(t, ctx, addr) if err != nil { - t.Fatal(err) + require.NoError(t, err) } + t.Cleanup(func() { + require.NoError(t, p.close(t)) + }) + assert.NoError(t, p.EnsureVersionTable(ctx)) - err = ps.Lock(ctx) + oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = p.db.ExecContext(ctx, oldTableCreate) + assert.NoError(t, err) + + assert.NoError(t, p.EnsureVersionTable(ctx)) + }) +} + +func TestEnsureTable_OldTable(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) } - err = ps.Unlock(ctx) + addr := pgConnectionString(ip, port) + p, err := open(t, ctx, addr) + if err != nil { + require.NoError(t, err) + } + t.Cleanup(func() { + require.NoError(t, p.close(t)) + }) + + oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = p.db.ExecContext(ctx, oldTableCreate) + assert.NoError(t, err) + + tableExists := false + oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')" + assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) + assert.True(t, tableExists) + + query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" + assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists)) + assert.False(t, tableExists) + + assert.NoError(t, p.EnsureVersionTable(ctx)) + + assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) + assert.False(t, tableExists) + assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists)) + assert.True(t, tableExists) + }) +} + +func TestRollback(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + ip, port, err := c.FirstPort() if err != nil { t.Fatal(err) } + + addr := pgConnectionString(ip, port) + p, err := open(t, ctx, addr) + if err != nil { + require.NoError(t, err) + } + t.Cleanup(func() { + require.NoError(t, p.close(t)) + }) + + assert.NoError(t, p.StartRun(ctx)) + assert.NoError(t, p.EnsureVersionTable(ctx)) + assert.NoError(t, p.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2)) + var exists bool + query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))" + assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists)) + assert.True(t, exists) + assert.NoError(t, p.Rollback()) + + assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists)) + assert.False(t, exists) }) } diff --git a/internal/db/schema/postgres/testing.go b/internal/db/schema/postgres/testing.go index c736ae7a04..42d978d24e 100644 --- a/internal/db/schema/postgres/testing.go +++ b/internal/db/schema/postgres/testing.go @@ -32,11 +32,11 @@ import ( "bytes" "context" "database/sql" - "io" "testing" "time" "github.com/golang-migrate/migrate/v4/database" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -45,24 +45,21 @@ func test(t *testing.T, d *Postgres, migration []byte) { if migration == nil { t.Fatal("test must provide migration reader") } + ctx := context.Background() + + v, alreadyRan, dirty, err := d.CurrentState(ctx) + require.NoError(t, err) + assert.False(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, nilVersion, v) - testNilVersion(t, d) // test first testLockAndUnlock(t, d) - testRun(t, d, bytes.NewReader(migration)) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.NoError(t, d.Run(ctx, bytes.NewReader(migration), 1)) + testSetVersion(t, d) // also tests CurrentState() // drop breaks the driver, so test it last. - testDrop(t, d) -} - -func testNilVersion(t *testing.T, d *Postgres) { - ctx := context.Background() - v, _, err := d.CurrentState(ctx) - if err != nil { - t.Fatal(err) - } - if v != database.NilVersion { - t.Fatalf("Version: expected Version to be NilVersion (-1), got %v", v) - } + assert.NoError(t, d.drop(ctx)) } func testLockAndUnlock(t *testing.T, d *Postgres) { @@ -71,47 +68,20 @@ func testLockAndUnlock(t *testing.T, d *Postgres) { ctx, _ = context.WithTimeout(ctx, 15*time.Second) // locking twice is ok, no error - if err := d.Lock(ctx); err != nil { - t.Fatalf("got error, expected none: %v", err) - } - if err := d.Lock(ctx); err != nil { - t.Fatalf("got error, expected none: %v", err) - } + require.NoError(t, d.Lock(ctx)) + assert.NoError(t, d.Lock(ctx)) // Unlock - if err := d.Unlock(ctx); err != nil { - t.Fatalf("error unlocking: %v", err) - } + assert.NoError(t, d.Unlock(ctx)) // try to Lock - if err := d.Lock(ctx); err != nil { - t.Fatalf("got error, expected none: %v", err) - } - if err := d.Unlock(ctx); err != nil { - t.Fatalf("got error, expected none: %v", err) - } -} - -func testRun(t *testing.T, d *Postgres, migration io.Reader) { - ctx := context.Background() - if migration == nil { - t.Fatal("migration can't be nil") - } - - if err := d.Run(ctx, migration, 1); err != nil { - t.Fatal(err) - } -} - -func testDrop(t *testing.T, d *Postgres) { - ctx := context.Background() - if err := d.drop(ctx); err != nil { - t.Fatal(err) - } + assert.NoError(t, d.Lock(ctx)) + assert.NoError(t, d.Unlock(ctx)) } func testSetVersion(t *testing.T, d *Postgres) { ctx := context.Background() + require.NoError(t, d.EnsureVersionTable(ctx)) // nolint:maligned testCases := []struct { name string @@ -132,11 +102,12 @@ func testSetVersion(t *testing.T, d *Postgres) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + // d.EnsureVersionTable(ctx) err := d.setVersion(ctx, tc.version, tc.dirty) if err != tc.expectedErr { t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr) } - v, dirty, readErr := d.CurrentState(ctx) + v, _, dirty, readErr := d.CurrentState(ctx) if readErr != tc.expectedReadErr { t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr) }