From 75388f727a640eaee0656001a5575853707fc482 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Fri, 30 Jul 2021 08:49:56 -0400 Subject: [PATCH] test: Refactor to use dbtest package instead of docker This swaps the test initialization to use the dbtest package so each test will get a fresh database created from a template database, instead of a completely new docker container. --- internal/cmd/base/dev.go | 7 +- internal/cmd/base/option.go | 9 + internal/cmd/commands/database/funcs_test.go | 20 +- internal/cmd/commands/server/server_test.go | 2 +- internal/db/db.go | 3 - internal/db/db_test.go | 3 +- internal/db/schema/manager_test.go | 41 +- .../postgres/11/01_server_type_enum_test.go | 6 +- .../12/01_timestamp_sub_funcs_test.go | 6 +- .../postgres/14/warehouse_user_dim_test.go | 6 +- .../migrations/postgres/2/07_iam_test.go | 6 +- .../migrations/postgres/2/10_auth_test.go | 2 +- .../postgres/8/08_connection_test.go | 2 +- internal/db/schema/postgres/postgres_test.go | 690 +++++++----------- internal/db/schema/schema_test.go | 10 +- internal/db/testing.go | 12 +- internal/oplog/testing.go | 4 +- .../credentialstore_service_test.go | 2 +- internal/servers/controller/testing.go | 1 + 19 files changed, 357 insertions(+), 475 deletions(-) diff --git a/internal/cmd/base/dev.go b/internal/cmd/base/dev.go index fa7d7e4f93..4ab7a92675 100644 --- a/internal/cmd/base/dev.go +++ b/internal/cmd/base/dev.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/testing/dbtest" capoidc "github.com/hashicorp/cap/oidc" "github.com/hashicorp/go-multierror" ) @@ -39,7 +40,11 @@ func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error { switch b.DatabaseUrl { case "": - c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage)) + if opts.withDatabaseTemplate != "" { + c, url, _, err = dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(opts.withDatabaseTemplate)) + } else { + c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage)) + } // In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop // function before returning from this method defer func() { diff --git a/internal/cmd/base/option.go b/internal/cmd/base/option.go index d83119f537..f528d6e282 100644 --- a/internal/cmd/base/option.go +++ b/internal/cmd/base/option.go @@ -28,6 +28,7 @@ type Options struct { withSkipTargetCreation bool withContainerImage string withDialect string + withDatabaseTemplate string withEventerConfig *event.EventerConfig withEventFlags *EventFlags withAttributeFieldPrefix string @@ -149,3 +150,11 @@ func WithStatusCode(statusCode int) Option { o.withStatusCode = statusCode } } + +// WithDatabaseTemplate allows for using an existing database template for +// initializing the boundary database. +func WithDatabaseTemplate(template string) Option { + return func(o *Options) { + o.withDatabaseTemplate = template + } +} diff --git a/internal/cmd/commands/database/funcs_test.go b/internal/cmd/commands/database/funcs_test.go index 2ff72110ff..fe41b46560 100644 --- a/internal/cmd/commands/database/funcs_test.go +++ b/internal/cmd/commands/database/funcs_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/hashicorp/boundary/internal/cmd/base" - "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/schema" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -15,7 +15,7 @@ import ( func TestMigrateDatabase(t *testing.T) { ctx := context.Background() - dialect := "postgres" + dialect := dbtest.Postgres cases := []struct { name string @@ -29,7 +29,7 @@ func TestMigrateDatabase(t *testing.T) { name: "not_initialized_expected_not_intialized", initialized: false, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -43,7 +43,7 @@ func TestMigrateDatabase(t *testing.T) { name: "basic_initialized_expects_initialized", initialized: true, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -69,7 +69,7 @@ func TestMigrateDatabase(t *testing.T) { name: "basic_initialized_expects_not_initialized", initialized: false, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -95,7 +95,7 @@ func TestMigrateDatabase(t *testing.T) { name: "old_version_table_used_intialized", initialized: true, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -115,7 +115,7 @@ func TestMigrateDatabase(t *testing.T) { name: "old_version_table_used_not_intialized", initialized: false, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -149,7 +149,7 @@ func TestMigrateDatabase(t *testing.T) { name: "cant_get_lock_initialized", initialized: true, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -172,7 +172,7 @@ func TestMigrateDatabase(t *testing.T) { name: "cant_get_lock_not_initialized", initialized: false, urlProvider: func() string { - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -210,7 +210,7 @@ func TestVerifyOplogIsEmpty(t *testing.T) { dialect := "postgres" ctx := context.Background() - c, u, _, err := db.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) diff --git a/internal/cmd/commands/server/server_test.go b/internal/cmd/commands/server/server_test.go index 0b4f81ce73..2c51f3890b 100644 --- a/internal/cmd/commands/server/server_test.go +++ b/internal/cmd/commands/server/server_test.go @@ -67,7 +67,7 @@ func testServerCommand(t *testing.T, opts testServerCommandOpts) *Command { cmd.Server.DevTargetId = defaultTestTargetId } - err = cmd.CreateDevDatabase(cmd.Context, base.WithContainerImage("postgres"), base.WithSkipOidcAuthMethodCreation()) + err = cmd.CreateDevDatabase(cmd.Context, base.WithDatabaseTemplate("boundary_template"), base.WithSkipOidcAuthMethodCreation()) if err != nil { if cmd.DevDatabaseCleanupFunc != nil { require.NoError(cmd.DevDatabaseCleanupFunc()) diff --git a/internal/db/db.go b/internal/db/db.go index d3fd54b57b..f506c98457 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -7,7 +7,6 @@ import ( "math" "time" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/go-hclog" "github.com/jinzhu/gorm" @@ -23,8 +22,6 @@ func init() { pq.EnableInfinityTs(NegativeInfinityTS, PositiveInfinityTS) } -var StartDbInDocker = docker.StartDbInDocker - type DbType int const ( diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 75fc507588..7329ceb17b 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -3,12 +3,13 @@ package db import ( "testing" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/lib/pq" "github.com/stretchr/testify/assert" ) func TestOpen(t *testing.T) { - cleanup, url, _, err := StartDbInDocker("postgres") + cleanup, url, _, err := dbtest.StartUsingTemplate(dbtest.Postgres) if err != nil { t.Fatal(err) } diff --git a/internal/db/schema/manager_test.go b/internal/db/schema/manager_test.go index 3213bcb131..84b522d563 100644 --- a/internal/db/schema/manager_test.go +++ b/internal/db/schema/manager_test.go @@ -8,16 +8,16 @@ import ( "testing" "github.com/hashicorp/boundary/internal/db/schema/postgres" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewManager(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -37,18 +37,15 @@ func TestNewManager(t *testing.T) { } func TestCurrentState(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) t.Cleanup(func() { if err := c(); err != nil { t.Fatalf("Got error at cleanup: %v", err) } }) require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, c()) - }) ctx := context.Background() d, err := sql.Open(dialect, u) require.NoError(t, err) @@ -79,9 +76,9 @@ func TestCurrentState(t *testing.T) { } func TestRollForward(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -104,13 +101,13 @@ func TestRollForward(t *testing.T) { } func TestRollForward_NotFromFresh(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres oState := migrationStates[dialect] nState := createPartialMigrationState(oState, 8) migrationStates[dialect] = nState - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -142,9 +139,9 @@ func TestRollForward_NotFromFresh(t *testing.T) { } func TestRunMigration_canceledContext(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -163,7 +160,7 @@ func TestRunMigration_canceledContext(t *testing.T) { } func TestRollForward_BadSQL(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres oState := migrationStates[dialect] defer func() { migrationStates[dialect] = oState }() @@ -172,7 +169,7 @@ func TestRollForward_BadSQL(t *testing.T) { nState.upMigrations[10] = []byte("SELECT 1 FROM NonExistantTable;") migrationStates[dialect] = nState - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -194,9 +191,9 @@ func TestRollForward_BadSQL(t *testing.T) { func TestManager_ExclusiveLock(t *testing.T) { ctx := context.Background() - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -219,9 +216,9 @@ func TestManager_ExclusiveLock(t *testing.T) { func TestManager_SharedLock(t *testing.T) { ctx := context.Background() - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -248,9 +245,9 @@ func TestManager_SharedLock(t *testing.T) { func Test_GetMigrationLog(t *testing.T) { t.Parallel() ctx := context.Background() - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) diff --git a/internal/db/schema/migrations/postgres/11/01_server_type_enum_test.go b/internal/db/schema/migrations/postgres/11/01_server_type_enum_test.go index bd77068411..6d8779519b 100644 --- a/internal/db/schema/migrations/postgres/11/01_server_type_enum_test.go +++ b/internal/db/schema/migrations/postgres/11/01_server_type_enum_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/stretchr/testify/require" ) @@ -40,10 +40,10 @@ func Test_ServerEnumChanges(t *testing.T) { require := require.New(t) const priorMigration = 10007 const serverEnumMigration = 11001 - dialect := "postgres" + dialect := dbtest.Postgres ctx := context.Background() - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(err) t.Cleanup(func() { require.NoError(c()) diff --git a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go index 3eda4f90f0..eb9fac8d70 100644 --- a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go +++ b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/stretchr/testify/require" ) @@ -34,9 +34,9 @@ func testSetupDb(ctx context.Context, t *testing.T) *sql.DB { t.Helper() require := require.New(t) - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(err) t.Cleanup(func() { require.NoError(c()) diff --git a/internal/db/schema/migrations/postgres/14/warehouse_user_dim_test.go b/internal/db/schema/migrations/postgres/14/warehouse_user_dim_test.go index cb7410b492..1b40812f2b 100644 --- a/internal/db/schema/migrations/postgres/14/warehouse_user_dim_test.go +++ b/internal/db/schema/migrations/postgres/14/warehouse_user_dim_test.go @@ -9,12 +9,12 @@ import ( "github.com/hashicorp/boundary/internal/authtoken" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/session" "github.com/hashicorp/boundary/internal/target" + "github.com/hashicorp/boundary/testing/dbtest" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" @@ -30,9 +30,9 @@ func TestMigrations_UserDimension(t *testing.T) { t.Parallel() assert, require := assert.New(t), require.New(t) ctx := context.Background() - dialect := "postgres" + dialect := dbtest.Postgres - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(err) t.Cleanup(func() { require.NoError(c()) diff --git a/internal/db/schema/migrations/postgres/2/07_iam_test.go b/internal/db/schema/migrations/postgres/2/07_iam_test.go index 97b1dbb695..e24549b33c 100644 --- a/internal/db/schema/migrations/postgres/2/07_iam_test.go +++ b/internal/db/schema/migrations/postgres/2/07_iam_test.go @@ -8,8 +8,8 @@ import ( "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,10 +30,10 @@ func Test_PrimaryAuthMethodChanges(t *testing.T) { // // 4) asserting some bits about the state of the db. assert, require := assert.New(t), require.New(t) - dialect := "postgres" + dialect := dbtest.Postgres ctx := context.Background() - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(err) t.Cleanup(func() { require.NoError(c()) diff --git a/internal/db/schema/migrations/postgres/2/10_auth_test.go b/internal/db/schema/migrations/postgres/2/10_auth_test.go index 8d7fc74b49..9fe69827bc 100644 --- a/internal/db/schema/migrations/postgres/2/10_auth_test.go +++ b/internal/db/schema/migrations/postgres/2/10_auth_test.go @@ -18,7 +18,7 @@ func Test_AuthMethodSubtypes(t *testing.T) { t.Parallel() assert, require := assert.New(t), require.New(t) ctx := context.Background() - conn, _ := db.TestSetup(t, "postgres") + conn, _ := db.TestSetup(t, "postgres", db.WithTemplate("template1")) rw := db.New(conn) rootWrapper := db.TestWrapper(t) kmsCache := kms.TestKms(t, conn, rootWrapper) diff --git a/internal/db/schema/migrations/postgres/8/08_connection_test.go b/internal/db/schema/migrations/postgres/8/08_connection_test.go index bdbb09324b..35367ce9f1 100644 --- a/internal/db/schema/migrations/postgres/8/08_connection_test.go +++ b/internal/db/schema/migrations/postgres/8/08_connection_test.go @@ -38,7 +38,7 @@ func TestMigration(t *testing.T) { selectQuery = `select session_id, server_id from session_connection_testing order by session_id` ) - conn, _ := db.TestSetup(t, "postgres") + conn, _ := db.TestSetup(t, "postgres", db.WithTemplate("template1")) db := conn.DB() _, err := db.Exec(createTables) require.NoError(err) diff --git a/internal/db/schema/postgres/postgres_test.go b/internal/db/schema/postgres/postgres_test.go index b56c8010f7..2b16c73170 100644 --- a/internal/db/schema/postgres/postgres_test.go +++ b/internal/db/schema/postgres/postgres_test.go @@ -33,490 +33,352 @@ package postgres import ( "bytes" "context" - "database/sql" - sqldriver "database/sql/driver" "fmt" - "io" - "log" "strconv" "strings" "testing" - "github.com/dhui/dktest" - "github.com/golang-migrate/migrate/v4/dktesting" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -const ( - pgPassword = "postgres" -) - -var ( - opts = dktest.Options{ - Env: map[string]string{"POSTGRES_PASSWORD": pgPassword}, - PortRequired: true, ReadyFunc: isReady, - } - // Supported versions: https://www.postgresql.org/support/versioning/ - specs = []dktesting.ContainerSpec{ - {ImageName: "postgres:12", Options: opts}, - } -) - -func pgConnectionString(host, port string) string { - return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port) -} - -func isReady(ctx context.Context, c dktest.ContainerInfo) bool { - ip, port, err := c.FirstPort() - if err != nil { - return false - } - - db, err := sql.Open("postgres", pgConnectionString(ip, port)) - if err != nil { - return false - } - defer func() { - if err := db.Close(); err != nil { - log.Println("close error:", err) - } - }() - if err = db.PingContext(ctx); err != nil { - switch err { - case sqldriver.ErrBadConn, io.EOF: - return false - default: - log.Println(err) - } - return false - } - - return true -} - func TestDbStuff(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := d.close(t); err != nil { - t.Error(err) - } - }() - test(t, d, []byte("SELECT 1")) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + test(t, d, []byte("SELECT 1")) } func TestCurrentState_NoVersionTable(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := d.close(t); err != nil { - t.Error(err) - } - }() - // Drop the version table so calls to CurrentState don't rely on that - require.NoError(t, d.drop(ctx)) - - v, alreadyRan, dirt, err := d.CurrentState(ctx) - assert.NoError(t, err) - assert.False(t, alreadyRan) - assert.Equal(t, v, nilVersion) - assert.False(t, dirt) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + // Drop the version table so calls to CurrentState don't rely on that + require.NoError(t, d.drop(ctx)) + v, alreadyRan, dirt, err := d.CurrentState(ctx) + assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.Equal(t, v, nilVersion) + assert.False(t, dirt) } func TestCurrentState_ToManyTables(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := d.close(t); err != nil { - t.Error(err) - } - }() - - // Create the most recent table - require.NoError(t, d.EnsureVersionTable(ctx)) - - // Create the legacy version of the table. - oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)` - _, err = d.conn.ExecContext(ctx, oldTableCreate) - require.NoError(t, err) - v, alreadyRan, dirt, err := d.CurrentState(ctx) - assert.Error(t, err) - assert.True(t, alreadyRan) - assert.Equal(t, v, nilVersion) - assert.False(t, dirt) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + // Create the most recent table + require.NoError(t, d.EnsureVersionTable(ctx)) + + // Create the legacy version of the table. + oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = d.conn.ExecContext(ctx, oldTableCreate) + require.NoError(t, err) + v, alreadyRan, dirt, err := d.CurrentState(ctx) + assert.Error(t, err) + assert.True(t, alreadyRan) + assert.Equal(t, v, nilVersion) + assert.False(t, dirt) } func TestMultiStatement(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := open(t, ctx, u) + require.NoError(t, err) - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := d.close(t); err != nil { - t.Error(err) - } - }() - if err := d.EnsureVersionTable(ctx); err != nil { - t.Fatalf("expected err to be nil, got %v", err) - } + if err := d.EnsureVersionTable(ctx); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } - if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil { - t.Fatalf("expected err to be nil, got %v", err) - } + if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } - // make sure second table exists - var exists bool - if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { - t.Fatal(err) - } - if !exists { - t.Fatalf("expected table bar to exist") - } - }) + // make sure second table exists + var exists bool + if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table bar to exist") + } } func TestTransaction(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - if err != nil { - t.Fatal(err) - } - defer func() { - if err := d.close(t); err != nil { - t.Error(err) - } - }() - - v, alreadyRan, dirty, err := d.CurrentState(ctx) - assert.NoError(t, err) - assert.False(t, alreadyRan) - assert.False(t, dirty) - assert.Equal(t, -1, v) - - // Fail the initial setup of the db. - assert.NoError(t, d.StartRun(ctx)) - assert.NoError(t, d.EnsureVersionTable(ctx)) - assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3)) - assert.Error(t, d.CommitRun()) - - v, alreadyRan, dirty, err = d.CurrentState(ctx) - assert.NoError(t, err) - assert.False(t, alreadyRan) - assert.False(t, dirty) - assert.Equal(t, -1, v) - - assert.NoError(t, d.StartRun(ctx)) - assert.NoError(t, d.EnsureVersionTable(ctx)) - assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2)) - assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3)) - assert.NoError(t, d.CommitRun()) - - v, alreadyRan, dirty, err = d.CurrentState(ctx) - assert.NoError(t, err) - assert.True(t, alreadyRan) - assert.False(t, dirty) - assert.Equal(t, 3, v) - - assert.NoError(t, d.StartRun(ctx)) - assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20)) - assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30)) - assert.Error(t, d.CommitRun()) - - v, alreadyRan, dirty, err = d.CurrentState(ctx) - assert.NoError(t, err) - assert.True(t, alreadyRan) - assert.False(t, dirty) - assert.Equal(t, 3, v) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + v, alreadyRan, dirty, err := d.CurrentState(ctx) + assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, -1, v) + + // Fail the initial setup of the db. + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3)) + assert.Error(t, d.CommitRun()) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) + assert.NoError(t, err) + assert.False(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, -1, v) + + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2)) + assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3)) + assert.NoError(t, d.CommitRun()) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) + assert.NoError(t, err) + assert.True(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, 3, v) + + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20)) + assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30)) + assert.Error(t, d.CommitRun()) + + v, alreadyRan, dirty, err = d.CurrentState(ctx) + assert.NoError(t, err) + assert.True(t, alreadyRan) + assert.False(t, dirty) + assert.Equal(t, 3, v) } func TestWithSchema(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - require.NoError(t, err) - - addr := pgConnectionString(ip, port) - d, err := open(t, ctx, addr) - require.NoError(t, err) - defer func() { - if err := d.close(t); err != nil { - t.Fatal(err) - } - }() - require.NoError(t, d.EnsureVersionTable(ctx)) - - // create foobar schema - require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1)) - - // re-connect using that schema - d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar", - pgPassword, ip, port)) - require.NoError(t, err) - defer func() { - if err := d2.close(t); err != nil { - t.Fatal(err) - } - }() - - version, alreadyRan, _, err := d2.CurrentState(ctx) - require.NoError(t, err) - require.Equal(t, nilVersion, version) - assert.False(t, alreadyRan) - - // now update CurrentState and compare - require.NoError(t, d2.EnsureVersionTable(ctx)) - require.NoError(t, d2.setVersion(ctx, 2, false)) - version, alreadyRan, _, err = d2.CurrentState(ctx) - require.NoError(t, err) - require.Equal(t, 2, version) - assert.True(t, alreadyRan) - - // meanwhile, the public schema still has the other CurrentState - version, alreadyRan, _, err = d.CurrentState(ctx) - require.NoError(t, err) - require.Equal(t, 1, version) - assert.True(t, alreadyRan) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) -} + d, err := open(t, ctx, u) + require.NoError(t, err) -func TestPostgres_Lock(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } + require.NoError(t, d.EnsureVersionTable(ctx)) - addr := pgConnectionString(ip, port) - ps, err := open(t, ctx, addr) - if err != nil { + // create foobar schema + require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1)) + + // re-connect using that schema + d2, err := open(t, ctx, fmt.Sprintf("%s&search_path=foobar", u)) + require.NoError(t, err) + defer func() { + if err := d2.close(t); err != nil { t.Fatal(err) } + }() - test(t, ps, []byte("SELECT 1")) + version, alreadyRan, _, err := d2.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, nilVersion, version) + assert.False(t, alreadyRan) + + // now update CurrentState and compare + require.NoError(t, d2.EnsureVersionTable(ctx)) + require.NoError(t, d2.setVersion(ctx, 2, false)) + version, alreadyRan, _, err = d2.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, 2, version) + assert.True(t, alreadyRan) + + // meanwhile, the public schema still has the other CurrentState + version, alreadyRan, _, err = d.CurrentState(ctx) + require.NoError(t, err) + require.Equal(t, 1, version) + assert.True(t, alreadyRan) +} - err = ps.Lock(ctx) - if err != nil { - t.Fatal(err) - } +func TestPostgres_Lock(t *testing.T) { + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := open(t, ctx, u) + require.NoError(t, err) - err = ps.Unlock(ctx) - if err != nil { - t.Fatal(err) - } + test(t, d, []byte("SELECT 1")) - err = ps.Lock(ctx) - if err != nil { - t.Fatal(err) - } + err = d.Lock(ctx) + if err != nil { + t.Fatal(err) + } - err = ps.Unlock(ctx) - if err != nil { - t.Fatal(err) - } + err = d.Unlock(ctx) + if err != nil { + t.Fatal(err) + } - // make sure we call call Unlock in an idempotent manner. - err = ps.Unlock(ctx) - if err != nil { - t.Fatal(err) - } - }) + err = d.Lock(ctx) + if err != nil { + t.Fatal(err) + } + + err = d.Unlock(ctx) + if err != nil { + t.Fatal(err) + } + + // make sure we call call Unlock in an idempotent manner. + err = d.Unlock(ctx) + if err != nil { + t.Fatal(err) + } } func TestEnsureTable_Fresh(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - p, err := open(t, ctx, addr) - if err != nil { - require.NoError(t, err) - } - t.Cleanup(func() { - require.NoError(t, p.close(t)) - }) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := open(t, ctx, u) + require.NoError(t, err) - tableCreated := false - query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" - assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated)) - assert.False(t, tableCreated) + tableCreated := false + query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" + assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableCreated)) + assert.False(t, tableCreated) - assert.NoError(t, p.EnsureVersionTable(ctx)) - assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated)) - assert.True(t, tableCreated) - }) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableCreated)) + assert.True(t, tableCreated) } func TestEnsureTable_ExistingTable(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := open(t, ctx, u) + require.NoError(t, err) - addr := pgConnectionString(ip, port) - p, err := open(t, ctx, addr) - if err != nil { - require.NoError(t, err) - } - t.Cleanup(func() { - require.NoError(t, p.close(t)) - }) - assert.NoError(t, p.EnsureVersionTable(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) - oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` - _, err = p.db.ExecContext(ctx, oldTableCreate) - assert.NoError(t, err) + oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = d.db.ExecContext(ctx, oldTableCreate) + assert.NoError(t, err) - assert.NoError(t, p.EnsureVersionTable(ctx)) - }) + assert.NoError(t, d.EnsureVersionTable(ctx)) } func TestEnsureTable_OldTable(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - p, err := open(t, ctx, addr) - if err != nil { - require.NoError(t, err) - } - t.Cleanup(func() { - require.NoError(t, p.close(t)) - }) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) + }) + d, err := open(t, ctx, u) + require.NoError(t, err) - oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` - _, err = p.db.ExecContext(ctx, oldTableCreate) - assert.NoError(t, err) + oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)` + _, err = d.db.ExecContext(ctx, oldTableCreate) + assert.NoError(t, err) - tableExists := false - oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')" - assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) - assert.True(t, tableExists) + tableExists := false + oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')" + assert.NoError(t, d.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) + assert.True(t, tableExists) - query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" - assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists)) - assert.False(t, tableExists) + query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')" + assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableExists)) + assert.False(t, tableExists) - assert.NoError(t, p.EnsureVersionTable(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) - assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) - assert.False(t, tableExists) - assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists)) - assert.True(t, tableExists) - }) + assert.NoError(t, d.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists)) + assert.False(t, tableExists) + assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableExists)) + assert.True(t, tableExists) } func TestRollback(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - p, err := open(t, ctx, addr) - if err != nil { - require.NoError(t, err) - } - t.Cleanup(func() { - require.NoError(t, p.close(t)) - }) - - assert.NoError(t, p.StartRun(ctx)) - assert.NoError(t, p.EnsureVersionTable(ctx)) - assert.NoError(t, p.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2)) - var exists bool - query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))" - assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists)) - assert.True(t, exists) - assert.NoError(t, p.Rollback()) - - assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists)) - assert.False(t, exists) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + assert.NoError(t, d.StartRun(ctx)) + assert.NoError(t, d.EnsureVersionTable(ctx)) + assert.NoError(t, d.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2)) + var exists bool + query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))" + assert.NoError(t, d.conn.QueryRowContext(context.Background(), query).Scan(&exists)) + assert.True(t, exists) + assert.NoError(t, d.Rollback()) + + assert.NoError(t, d.conn.QueryRowContext(context.Background(), query).Scan(&exists)) + assert.False(t, exists) } func TestRun_Error(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ctx := context.Background() - ip, port, err := c.FirstPort() - if err != nil { - t.Fatal(err) - } - - addr := pgConnectionString(ip, port) - p, err := open(t, ctx, addr) - if err != nil { - require.NoError(t, err) - } - t.Cleanup(func() { - require.NoError(t, p.close(t)) - }) - - err = p.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")), 2) - assert.Error(t, err) + t.Parallel() + ctx := context.Background() + c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, c()) }) + d, err := open(t, ctx, u) + require.NoError(t, err) + + err = d.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")), 2) + assert.Error(t, err) } func Test_computeLineFromPos(t *testing.T) { diff --git a/internal/db/schema/schema_test.go b/internal/db/schema/schema_test.go index e921c9060f..634e881ef5 100644 --- a/internal/db/schema/schema_test.go +++ b/internal/db/schema/schema_test.go @@ -5,16 +5,16 @@ import ( "database/sql" "testing" - "github.com/hashicorp/boundary/internal/docker" + "github.com/hashicorp/boundary/testing/dbtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMigrateStore(t *testing.T) { - dialect := "postgres" + dialect := dbtest.Postgres ctx := context.Background() - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, c()) @@ -45,10 +45,10 @@ func TestMigrateStore(t *testing.T) { func Test_MigrateStore_WithMigrationStates(t *testing.T) { assert, require := assert.New(t), require.New(t) - dialect := "postgres" + dialect := dbtest.Postgres ctx := context.Background() - c, u, _, err := docker.StartDbInDocker(dialect) + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) require.NoError(err) t.Cleanup(func() { require.NoError(c()) diff --git a/internal/db/testing.go b/internal/db/testing.go index 440f748c02..9156951039 100644 --- a/internal/db/testing.go +++ b/internal/db/testing.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/boundary/internal/db/schema" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/oplog/store" + "github.com/hashicorp/boundary/testing/dbtest" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/go-kms-wrapping/wrappers/aead" "github.com/jinzhu/gorm" @@ -30,7 +31,7 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin switch opts.withTestDatabaseUrl { case "": - cleanup, url, _, err = StartDbInDocker(dialect) + cleanup, url, _, err = dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(opts.withTemplate)) if err != nil { t.Fatal(err) } @@ -172,6 +173,7 @@ type testOptions struct { withOperation oplog.OpType withTestDatabaseUrl string withResourcePrivateId bool + withTemplate string } func getDefaultTestOptions() testOptions { @@ -211,3 +213,11 @@ func WithResourcePrivateId(enable bool) TestOption { o.withResourcePrivateId = enable } } + +// WithTemplate provides a way to specify the source database template for creating +// a database. +func WithTemplate(template string) TestOption { + return func(o *testOptions) { + o.withTemplate = template + } +} diff --git a/internal/oplog/testing.go b/internal/oplog/testing.go index 054fec0e04..5476b6b912 100644 --- a/internal/oplog/testing.go +++ b/internal/oplog/testing.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/hashicorp/boundary/internal/db/schema" - "github.com/hashicorp/boundary/internal/docker" "github.com/hashicorp/boundary/internal/oplog/oplog_test" + "github.com/hashicorp/boundary/testing/dbtest" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/go-kms-wrapping/wrappers/aead" "github.com/hashicorp/go-uuid" @@ -57,7 +57,7 @@ func testId(t *testing.T) string { func testInitDbInDocker(t *testing.T) (cleanup func() error, retURL string, err error) { t.Helper() - cleanup, retURL, _, err = docker.StartDbInDocker("postgres") + cleanup, retURL, _, err = dbtest.StartUsingTemplate(dbtest.Postgres) if err != nil { t.Fatal(err) } diff --git a/internal/servers/controller/handlers/credentialstores/credentialstore_service_test.go b/internal/servers/controller/handlers/credentialstores/credentialstore_service_test.go index d152109cd7..041ffcc638 100644 --- a/internal/servers/controller/handlers/credentialstores/credentialstore_service_test.go +++ b/internal/servers/controller/handlers/credentialstores/credentialstore_service_test.go @@ -124,7 +124,7 @@ func TestList(t *testing.T) { return } require.NoError(t, gErr) - assert.Empty(t, cmp.Diff(got, tc.res, protocmp.Transform())) + assert.ElementsMatch(t, got.Items, tc.res.Items) // Test anonymous listing got, gErr = s.ListCredentialStores(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId(), auth.WithUserId(auth.AnonymousUserId)), tc.req) diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 7e5258865d..79bfa0ed1f 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -582,6 +582,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { if opts.DisableOidcAuthMethodCreation { createOpts = append(createOpts, base.WithSkipOidcAuthMethodCreation()) } + createOpts = append(createOpts, base.WithDatabaseTemplate("boundary_template")) if err := tc.b.CreateDevDatabase(ctx, createOpts...); err != nil { t.Fatal(err) }