test: Refactor to use dbtest package instead of docker

This swaps the test initialization to use the dbtest package so each
test will get a fresh database created from a template database, instead
of a completely new docker container.
pull/1482/head
Timothy Messier 5 years ago
parent af5d1af568
commit 75388f727a
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/testing/dbtest"
capoidc "github.com/hashicorp/cap/oidc"
"github.com/hashicorp/go-multierror"
)
@ -39,7 +40,11 @@ func (b *Server) CreateDevDatabase(ctx context.Context, opt ...Option) error {
switch b.DatabaseUrl {
case "":
c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage))
if opts.withDatabaseTemplate != "" {
c, url, _, err = dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(opts.withDatabaseTemplate))
} else {
c, url, container, err = docker.StartDbInDocker(dialect, docker.WithContainerImage(opts.withContainerImage))
}
// In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop
// function before returning from this method
defer func() {

@ -28,6 +28,7 @@ type Options struct {
withSkipTargetCreation bool
withContainerImage string
withDialect string
withDatabaseTemplate string
withEventerConfig *event.EventerConfig
withEventFlags *EventFlags
withAttributeFieldPrefix string
@ -149,3 +150,11 @@ func WithStatusCode(statusCode int) Option {
o.withStatusCode = statusCode
}
}
// WithDatabaseTemplate allows for using an existing database template for
// initializing the boundary database.
func WithDatabaseTemplate(template string) Option {
return func(o *Options) {
o.withDatabaseTemplate = template
}
}

@ -6,8 +6,8 @@ import (
"testing"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -15,7 +15,7 @@ import (
func TestMigrateDatabase(t *testing.T) {
ctx := context.Background()
dialect := "postgres"
dialect := dbtest.Postgres
cases := []struct {
name string
@ -29,7 +29,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "not_initialized_expected_not_intialized",
initialized: false,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -43,7 +43,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "basic_initialized_expects_initialized",
initialized: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -69,7 +69,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "basic_initialized_expects_not_initialized",
initialized: false,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -95,7 +95,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "old_version_table_used_intialized",
initialized: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -115,7 +115,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "old_version_table_used_not_intialized",
initialized: false,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -149,7 +149,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "cant_get_lock_initialized",
initialized: true,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -172,7 +172,7 @@ func TestMigrateDatabase(t *testing.T) {
name: "cant_get_lock_not_initialized",
initialized: false,
urlProvider: func() string {
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -210,7 +210,7 @@ func TestVerifyOplogIsEmpty(t *testing.T) {
dialect := "postgres"
ctx := context.Background()
c, u, _, err := db.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())

@ -67,7 +67,7 @@ func testServerCommand(t *testing.T, opts testServerCommandOpts) *Command {
cmd.Server.DevTargetId = defaultTestTargetId
}
err = cmd.CreateDevDatabase(cmd.Context, base.WithContainerImage("postgres"), base.WithSkipOidcAuthMethodCreation())
err = cmd.CreateDevDatabase(cmd.Context, base.WithDatabaseTemplate("boundary_template"), base.WithSkipOidcAuthMethodCreation())
if err != nil {
if cmd.DevDatabaseCleanupFunc != nil {
require.NoError(cmd.DevDatabaseCleanupFunc())

@ -7,7 +7,6 @@ import (
"math"
"time"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/go-hclog"
"github.com/jinzhu/gorm"
@ -23,8 +22,6 @@ func init() {
pq.EnableInfinityTs(NegativeInfinityTS, PositiveInfinityTS)
}
var StartDbInDocker = docker.StartDbInDocker
type DbType int
const (

@ -3,12 +3,13 @@ package db
import (
"testing"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
)
func TestOpen(t *testing.T) {
cleanup, url, _, err := StartDbInDocker("postgres")
cleanup, url, _, err := dbtest.StartUsingTemplate(dbtest.Postgres)
if err != nil {
t.Fatal(err)
}

@ -8,16 +8,16 @@ import (
"testing"
"github.com/hashicorp/boundary/internal/db/schema/postgres"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewManager(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -37,18 +37,15 @@ func TestNewManager(t *testing.T) {
}
func TestCurrentState(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
t.Cleanup(func() {
if err := c(); err != nil {
t.Fatalf("Got error at cleanup: %v", err)
}
})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
ctx := context.Background()
d, err := sql.Open(dialect, u)
require.NoError(t, err)
@ -79,9 +76,9 @@ func TestCurrentState(t *testing.T) {
}
func TestRollForward(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -104,13 +101,13 @@ func TestRollForward(t *testing.T) {
}
func TestRollForward_NotFromFresh(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
oState := migrationStates[dialect]
nState := createPartialMigrationState(oState, 8)
migrationStates[dialect] = nState
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -142,9 +139,9 @@ func TestRollForward_NotFromFresh(t *testing.T) {
}
func TestRunMigration_canceledContext(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -163,7 +160,7 @@ func TestRunMigration_canceledContext(t *testing.T) {
}
func TestRollForward_BadSQL(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
oState := migrationStates[dialect]
defer func() { migrationStates[dialect] = oState }()
@ -172,7 +169,7 @@ func TestRollForward_BadSQL(t *testing.T) {
nState.upMigrations[10] = []byte("SELECT 1 FROM NonExistantTable;")
migrationStates[dialect] = nState
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -194,9 +191,9 @@ func TestRollForward_BadSQL(t *testing.T) {
func TestManager_ExclusiveLock(t *testing.T) {
ctx := context.Background()
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -219,9 +216,9 @@ func TestManager_ExclusiveLock(t *testing.T) {
func TestManager_SharedLock(t *testing.T) {
ctx := context.Background()
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -248,9 +245,9 @@ func TestManager_SharedLock(t *testing.T) {
func Test_GetMigrationLog(t *testing.T) {
t.Parallel()
ctx := context.Background()
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())

@ -7,7 +7,7 @@ import (
"time"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/require"
)
@ -40,10 +40,10 @@ func Test_ServerEnumChanges(t *testing.T) {
require := require.New(t)
const priorMigration = 10007
const serverEnumMigration = 11001
dialect := "postgres"
dialect := dbtest.Postgres
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())

@ -7,7 +7,7 @@ import (
"time"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/require"
)
@ -34,9 +34,9 @@ func testSetupDb(ctx context.Context, t *testing.T) *sql.DB {
t.Helper()
require := require.New(t)
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())

@ -9,12 +9,12 @@ import (
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/testing/dbtest"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
@ -30,9 +30,9 @@ func TestMigrations_UserDimension(t *testing.T) {
t.Parallel()
assert, require := assert.New(t), require.New(t)
ctx := context.Background()
dialect := "postgres"
dialect := dbtest.Postgres
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())

@ -8,8 +8,8 @@ import (
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -30,10 +30,10 @@ func Test_PrimaryAuthMethodChanges(t *testing.T) {
//
// 4) asserting some bits about the state of the db.
assert, require := assert.New(t), require.New(t)
dialect := "postgres"
dialect := dbtest.Postgres
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())

@ -18,7 +18,7 @@ func Test_AuthMethodSubtypes(t *testing.T) {
t.Parallel()
assert, require := assert.New(t), require.New(t)
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
conn, _ := db.TestSetup(t, "postgres", db.WithTemplate("template1"))
rw := db.New(conn)
rootWrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, rootWrapper)

@ -38,7 +38,7 @@ func TestMigration(t *testing.T) {
selectQuery = `select session_id, server_id from session_connection_testing order by session_id`
)
conn, _ := db.TestSetup(t, "postgres")
conn, _ := db.TestSetup(t, "postgres", db.WithTemplate("template1"))
db := conn.DB()
_, err := db.Exec(createTables)
require.NoError(err)

@ -33,490 +33,352 @@ package postgres
import (
"bytes"
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"log"
"strconv"
"strings"
"testing"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4/dktesting"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady,
}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:12", Options: opts},
}
)
func pgConnectionString(host, port string) string {
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("postgres", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func TestDbStuff(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
test(t, d, []byte("SELECT 1"))
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
test(t, d, []byte("SELECT 1"))
}
func TestCurrentState_NoVersionTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
// Drop the version table so calls to CurrentState don't rely on that
require.NoError(t, d.drop(ctx))
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
// Drop the version table so calls to CurrentState don't rely on that
require.NoError(t, d.drop(ctx))
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
}
func TestCurrentState_ToManyTables(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
// Create the most recent table
require.NoError(t, d.EnsureVersionTable(ctx))
// Create the legacy version of the table.
oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = d.conn.ExecContext(ctx, oldTableCreate)
require.NoError(t, err)
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.Error(t, err)
assert.True(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
// Create the most recent table
require.NoError(t, d.EnsureVersionTable(ctx))
// Create the legacy version of the table.
oldTableCreate := `create table if not exists schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = d.conn.ExecContext(ctx, oldTableCreate)
require.NoError(t, err)
v, alreadyRan, dirt, err := d.CurrentState(ctx)
assert.Error(t, err)
assert.True(t, alreadyRan)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
}
func TestMultiStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
if err := d.EnsureVersionTable(ctx); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
if err := d.EnsureVersionTable(ctx); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);"), 2); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
// make sure second table exists
var exists bool
if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
}
func TestTransaction(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
v, alreadyRan, dirty, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
// Fail the initial setup of the db.
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3))
assert.Error(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2))
assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3))
assert.NoError(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30))
assert.Error(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
v, alreadyRan, dirty, err := d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
// Fail the initial setup of the db.
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 from nonExistantTable"), 3))
assert.Error(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.False(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, -1, v)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text);"), 2))
assert.NoError(t, d.Run(ctx, strings.NewReader("SELECT 1;"), 3))
assert.NoError(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.Run(ctx, strings.NewReader("CREATE TABLE bar (bar text);"), 20))
assert.Error(t, d.Run(ctx, strings.NewReader("SELECT 1 FROM NonExistingTable"), 30))
assert.Error(t, d.CommitRun())
v, alreadyRan, dirty, err = d.CurrentState(ctx)
assert.NoError(t, err)
assert.True(t, alreadyRan)
assert.False(t, dirty)
assert.Equal(t, 3, v)
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
require.NoError(t, err)
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
require.NoError(t, err)
defer func() {
if err := d.close(t); err != nil {
t.Fatal(err)
}
}()
require.NoError(t, d.EnsureVersionTable(ctx))
// create foobar schema
require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1))
// re-connect using that schema
d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
pgPassword, ip, port))
require.NoError(t, err)
defer func() {
if err := d2.close(t); err != nil {
t.Fatal(err)
}
}()
version, alreadyRan, _, err := d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, nilVersion, version)
assert.False(t, alreadyRan)
// now update CurrentState and compare
require.NoError(t, d2.EnsureVersionTable(ctx))
require.NoError(t, d2.setVersion(ctx, 2, false))
version, alreadyRan, _, err = d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 2, version)
assert.True(t, alreadyRan)
// meanwhile, the public schema still has the other CurrentState
version, alreadyRan, _, err = d.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 1, version)
assert.True(t, alreadyRan)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
}
d, err := open(t, ctx, u)
require.NoError(t, err)
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
require.NoError(t, d.EnsureVersionTable(ctx))
addr := pgConnectionString(ip, port)
ps, err := open(t, ctx, addr)
if err != nil {
// create foobar schema
require.NoError(t, d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres"), 1))
// re-connect using that schema
d2, err := open(t, ctx, fmt.Sprintf("%s&search_path=foobar", u))
require.NoError(t, err)
defer func() {
if err := d2.close(t); err != nil {
t.Fatal(err)
}
}()
test(t, ps, []byte("SELECT 1"))
version, alreadyRan, _, err := d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, nilVersion, version)
assert.False(t, alreadyRan)
// now update CurrentState and compare
require.NoError(t, d2.EnsureVersionTable(ctx))
require.NoError(t, d2.setVersion(ctx, 2, false))
version, alreadyRan, _, err = d2.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 2, version)
assert.True(t, alreadyRan)
// meanwhile, the public schema still has the other CurrentState
version, alreadyRan, _, err = d.CurrentState(ctx)
require.NoError(t, err)
require.Equal(t, 1, version)
assert.True(t, alreadyRan)
}
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
func TestPostgres_Lock(t *testing.T) {
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
test(t, d, []byte("SELECT 1"))
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
err = d.Lock(ctx)
if err != nil {
t.Fatal(err)
}
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
err = d.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
// make sure we call call Unlock in an idempotent manner.
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
})
err = d.Lock(ctx)
if err != nil {
t.Fatal(err)
}
err = d.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
// make sure we call call Unlock in an idempotent manner.
err = d.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
}
func TestEnsureTable_Fresh(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
tableCreated := false
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.False(t, tableCreated)
tableCreated := false
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.False(t, tableCreated)
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.True(t, tableCreated)
})
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableCreated))
assert.True(t, tableCreated)
}
func TestEnsureTable_ExistingTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = p.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = d.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
assert.NoError(t, p.EnsureVersionTable(ctx))
})
assert.NoError(t, d.EnsureVersionTable(ctx))
}
func TestEnsureTable_OldTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = p.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
oldTableCreate := `CREATE TABLE IF NOT EXISTS schema_migrations (version bigint primary key, dirty boolean not null)`
_, err = d.db.ExecContext(ctx, oldTableCreate)
assert.NoError(t, err)
tableExists := false
oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')"
assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.True(t, tableExists)
tableExists := false
oldTableCheck := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = 'schema_migrations')"
assert.NoError(t, d.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.True(t, tableExists)
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.False(t, tableExists)
query := "SELECT exists (SELECT 1 FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_name = '" + defaultMigrationsTable + "')"
assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.False(t, tableExists)
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, p.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.False(t, tableExists)
assert.NoError(t, p.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.True(t, tableExists)
})
assert.NoError(t, d.db.QueryRowContext(ctx, oldTableCheck).Scan(&tableExists))
assert.False(t, tableExists)
assert.NoError(t, d.db.QueryRowContext(ctx, query).Scan(&tableExists))
assert.True(t, tableExists)
}
func TestRollback(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
assert.NoError(t, p.StartRun(ctx))
assert.NoError(t, p.EnsureVersionTable(ctx))
assert.NoError(t, p.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2))
var exists bool
query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))"
assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.True(t, exists)
assert.NoError(t, p.Rollback())
assert.NoError(t, p.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.False(t, exists)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
assert.NoError(t, d.StartRun(ctx))
assert.NoError(t, d.EnsureVersionTable(ctx))
assert.NoError(t, d.Run(ctx, bytes.NewReader([]byte("create table if not exists foo (foo text)")), 2))
var exists bool
query := "select exists (select 1 from information_schema.tables where table_name = 'foo' and table_schema = (select current_schema()))"
assert.NoError(t, d.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.True(t, exists)
assert.NoError(t, d.Rollback())
assert.NoError(t, d.conn.QueryRowContext(context.Background(), query).Scan(&exists))
assert.False(t, exists)
}
func TestRun_Error(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
err = p.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")), 2)
assert.Error(t, err)
t.Parallel()
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dbtest.Postgres, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := open(t, ctx, u)
require.NoError(t, err)
err = d.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")), 2)
assert.Error(t, err)
}
func Test_computeLineFromPos(t *testing.T) {

@ -5,16 +5,16 @@ import (
"database/sql"
"testing"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMigrateStore(t *testing.T) {
dialect := "postgres"
dialect := dbtest.Postgres
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
@ -45,10 +45,10 @@ func TestMigrateStore(t *testing.T) {
func Test_MigrateStore_WithMigrationStates(t *testing.T) {
assert, require := assert.New(t), require.New(t)
dialect := "postgres"
dialect := dbtest.Postgres
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())

@ -13,6 +13,7 @@ import (
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/oplog/store"
"github.com/hashicorp/boundary/testing/dbtest"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/jinzhu/gorm"
@ -30,7 +31,7 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin
switch opts.withTestDatabaseUrl {
case "":
cleanup, url, _, err = StartDbInDocker(dialect)
cleanup, url, _, err = dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(opts.withTemplate))
if err != nil {
t.Fatal(err)
}
@ -172,6 +173,7 @@ type testOptions struct {
withOperation oplog.OpType
withTestDatabaseUrl string
withResourcePrivateId bool
withTemplate string
}
func getDefaultTestOptions() testOptions {
@ -211,3 +213,11 @@ func WithResourcePrivateId(enable bool) TestOption {
o.withResourcePrivateId = enable
}
}
// WithTemplate provides a way to specify the source database template for creating
// a database.
func WithTemplate(template string) TestOption {
return func(o *testOptions) {
o.withTemplate = template
}
}

@ -7,8 +7,8 @@ import (
"testing"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/oplog/oplog_test"
"github.com/hashicorp/boundary/testing/dbtest"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/hashicorp/go-uuid"
@ -57,7 +57,7 @@ func testId(t *testing.T) string {
func testInitDbInDocker(t *testing.T) (cleanup func() error, retURL string, err error) {
t.Helper()
cleanup, retURL, _, err = docker.StartDbInDocker("postgres")
cleanup, retURL, _, err = dbtest.StartUsingTemplate(dbtest.Postgres)
if err != nil {
t.Fatal(err)
}

@ -124,7 +124,7 @@ func TestList(t *testing.T) {
return
}
require.NoError(t, gErr)
assert.Empty(t, cmp.Diff(got, tc.res, protocmp.Transform()))
assert.ElementsMatch(t, got.Items, tc.res.Items)
// Test anonymous listing
got, gErr = s.ListCredentialStores(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId(), auth.WithUserId(auth.AnonymousUserId)), tc.req)

@ -582,6 +582,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if opts.DisableOidcAuthMethodCreation {
createOpts = append(createOpts, base.WithSkipOidcAuthMethodCreation())
}
createOpts = append(createOpts, base.WithDatabaseTemplate("boundary_template"))
if err := tc.b.CreateDevDatabase(ctx, createOpts...); err != nil {
t.Fatal(err)
}

Loading…
Cancel
Save