Organize DB Schema Migration Code and DB Init Checks (#842)

Create Schema Manager and propagate contexts into the DB calls downstream.
Add checks at controller startup and database init for correct schema version and schema migration dirty bit being set.
pull/873/head
Todd Knight 5 years ago committed by GitHub
parent c141a050e5
commit ec6151d174

@ -78,7 +78,7 @@ perms-table:
gen: cleangen proto api migrations fmt
migrations:
$(MAKE) --environment-overrides -C internal/db/migrations/genmigrations migrations
$(MAKE) --environment-overrides -C internal/db/schema/migrations/generate migrations
### oplog requires protoc-gen-go v1.20.0 or later
# GO111MODULE=on go get -u github.com/golang/protobuf/protoc-gen-go@v1.40

@ -9,6 +9,7 @@ replace github.com/hashicorp/boundary/sdk => ./sdk
require (
github.com/armon/go-metrics v0.3.5
github.com/bufbuild/buf v0.33.0
github.com/dhui/dktest v0.3.3
github.com/fatih/color v1.10.0
github.com/favadi/protoc-go-inject-tag v1.1.0
github.com/go-bindata/go-bindata/v3 v3.1.3

@ -18,6 +18,7 @@ import (
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
@ -442,7 +443,7 @@ func (b *Server) ConnectToDatabase(dialect string) error {
return nil
}
func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
func (b *Server) CreateDevDatabase(ctx context.Context, dialect string, opt ...Option) error {
opts := getOpts(opt...)
var container, url string
@ -470,17 +471,25 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
return fmt.Errorf("unable to start dev database with dialect %s: %w", dialect, err)
}
_, err := db.InitStore(dialect, c, url)
_, err := schema.InitStore(ctx, dialect, url)
if err != nil {
return fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err)
err = fmt.Errorf("unable to initialize dev database with dialect %s: %w", dialect, err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.DevDatabaseCleanupFunc = c
b.DatabaseUrl = url
default:
if _, err := db.InitStore(dialect, c, b.DatabaseUrl); err != nil {
return fmt.Errorf("error initializing store: %w", err)
if _, err := schema.InitStore(ctx, dialect, b.DatabaseUrl); err != nil {
err = fmt.Errorf("error initializing store: %w", err)
if c != nil {
err = multierror.Append(err, c())
}
return err
}
}
@ -492,16 +501,25 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
}
if err := b.ConnectToDatabase(dialect); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
b.Database.LogMode(true)
if err := b.CreateGlobalKmsKeys(context.Background()); err != nil {
if err := b.CreateGlobalKmsKeys(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
if _, err := b.CreateInitialLoginRole(context.Background()); err != nil {
if _, err := b.CreateInitialLoginRole(ctx); err != nil {
if c != nil {
err = multierror.Append(err, c())
}
return err
}
@ -512,7 +530,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
return nil
}
if _, _, err := b.CreateInitialAuthMethod(context.Background()); err != nil {
if _, _, err := b.CreateInitialAuthMethod(ctx); err != nil {
return err
}
@ -523,7 +541,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
return nil
}
if _, _, err := b.CreateInitialScopes(context.Background()); err != nil {
if _, _, err := b.CreateInitialScopes(ctx); err != nil {
return err
}
@ -545,7 +563,7 @@ func (b *Server) CreateDevDatabase(dialect string, opt ...Option) error {
return nil
}
if _, err := b.CreateInitialTarget(context.Background()); err != nil {
if _, err := b.CreateInitialTarget(ctx); err != nil {
return err
}

@ -36,20 +36,14 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) {
Commands = map[string]cli.CommandFactory{
"server": func() (cli.Command, error) {
return &server.Command{
Server: base.NewServer(&base.Command{
UI: serverCmdUi,
ShutdownCh: base.MakeShutdownCh(),
}),
Server: base.NewServer(base.NewCommand(serverCmdUi)),
SighupCh: MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil
},
"dev": func() (cli.Command, error) {
return &dev.Command{
Server: base.NewServer(&base.Command{
UI: serverCmdUi,
ShutdownCh: base.MakeShutdownCh(),
}),
Server: base.NewServer(base.NewCommand(serverCmdUi)),
SighupCh: MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
}, nil

@ -7,9 +7,7 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/migrations"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/wrapper"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -186,8 +184,10 @@ func (c *InitCommand) Run(args []string) (retCode int) {
}()
}
if migrations.DevMigration != c.flagAllowDevMigrations {
if migrations.DevMigration {
dialect := "postgres"
if schema.DevMigration(dialect) != c.flagAllowDevMigrations {
if schema.DevMigration(dialect) {
c.UI.Error(base.WrapAtLength("This version of the binary has " +
"dev database schema updates which may not be supported in the " +
"next official release. To proceed anyways please use the " +
@ -263,34 +263,63 @@ func (c *InitCommand) Run(args []string) (retCode int) {
return 1
}
migrationUrl, err := config.ParseAddress(migrationUrlToParse)
if err != nil && err != config.ErrNotAUrl {
c.UI.Error(fmt.Errorf("Error parsing migration url: %w", err).Error())
// This database is used to keep an exclusive lock on the database for the
// remainder of the command
dBase, err := sql.Open(dialect, dbaseUrl)
if err != nil {
c.UI.Error(fmt.Errorf("Error establishing db connection for locking: %w", err).Error())
return 1
}
man, err := schema.NewManager(c.Context, dialect, dBase)
if err != nil {
c.UI.Error(fmt.Errorf("Error setting up schema manager for locking: %w", err).Error())
return 1
}
// Core migrations using the migration URL
{
c.srv.DatabaseUrl = strings.TrimSpace(migrationUrl)
ldb, err := sql.Open("postgres", c.srv.DatabaseUrl)
st, err := man.CurrentState(c.Context)
if err != nil {
c.UI.Error(fmt.Errorf("Error opening database to check init status: %w", err).Error())
c.UI.Error(fmt.Errorf("Error getting database state: %w", err).Error())
return 1
}
_, err = ldb.QueryContext(c.Context, "select version from schema_migrations")
switch {
case err == nil:
if base.Format(c.UI) == "table" {
c.UI.Info("Database already initialized.")
return 0
}
case errors.IsMissingTableError(err):
// Doesn't exist so we continue on
default:
c.UI.Error(fmt.Errorf("Error querying database for init status: %w", err).Error())
if st.Dirty {
c.UI.Error(base.WrapAtLength("Database is in a bad initialization " +
"state. Please revert back to the last known good state."))
return 1
}
if st.InitializationStarted {
// TODO: Separate from the "dirty" bit maintained by the schema
// manager maintain a bit which indicates that this full command
// was completed successfully (with all default resources being created).
// Use that bit to determine if a previous init was completed
// successfully or not.
c.UI.Error(base.WrapAtLength("Database has already been " +
"initialized. If the initialization did not complete successfully " +
"please revert the database to its fresh state."))
return 1
}
ran, err := db.InitStore("postgres", nil, c.srv.DatabaseUrl)
}
// This is an advisory locks on the DB which is released when the db session ends.
if err := man.ExclusiveLock(c.Context); err != nil {
c.UI.Error(fmt.Errorf("Error capturing an exclusive lock: %w", err).Error())
return 1
}
defer func() {
if err := man.ExclusiveUnlock(c.Context); err != nil {
c.UI.Error(fmt.Errorf("Unable to release exclusive lock to the database: %w", err).Error())
}
}()
migrationUrl, err := config.ParseAddress(migrationUrlToParse)
if err != nil && err != config.ErrNotAUrl {
c.UI.Error(fmt.Errorf("Error parsing migration url: %w", err).Error())
return 1
}
// Core migrations using the migration URL
{
migrationUrl = strings.TrimSpace(migrationUrl)
ran, err := schema.InitStore(c.Context, dialect, migrationUrl)
if err != nil {
c.UI.Error(fmt.Errorf("Error running database migrations: %w", err).Error())
return 1
@ -308,7 +337,7 @@ func (c *InitCommand) Run(args []string) (retCode int) {
// Everything after is done with normal database URL and is affecting actual data
c.srv.DatabaseUrl = strings.TrimSpace(dbaseUrl)
if err := c.srv.ConnectToDatabase("postgres"); err != nil {
if err := c.srv.ConnectToDatabase(dialect); err != nil {
c.UI.Error(fmt.Errorf("Error connecting to database after migrations: %w", err).Error())
return 1
}

@ -399,7 +399,7 @@ func (c *Command) Run(args []string) int {
if c.flagDisableDatabaseDestruction {
opts = append(opts, base.WithSkipDatabaseDestruction())
}
if err := c.CreateDevDatabase("postgres", opts...); err != nil {
if err := c.CreateDevDatabase(c.Context, "postgres", opts...); err != nil {
if err == docker.ErrDockerUnsupported {
c.UI.Error("Automatically starting a Docker container running Postgres is not currently supported on this platform. Please use -database-url to pass in a URL (or an env var or file reference to a URL) for connecting to an existing empty database.")
return 1
@ -417,7 +417,7 @@ func (c *Command) Run(args []string) int {
return 1
}
c.DatabaseUrl = strings.TrimSpace(dbaseUrl)
if err := c.CreateDevDatabase("postgres"); err != nil {
if err := c.CreateDevDatabase(c.Context, "postgres"); err != nil {
c.UI.Error(fmt.Errorf("Error connecting to database: %w", err).Error())
return 1
}

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/boundary/sdk/wrapper"
@ -341,6 +342,46 @@ func (c *Command) Run(args []string) int {
c.UI.Error(fmt.Errorf("Error connecting to database: %w", err).Error())
return 1
}
sMan, err := schema.NewManager(c.Context, "postgres", c.Database.DB())
if err != nil {
c.UI.Error(fmt.Errorf("Can't get schema manager: %w.", err).Error())
return 1
}
// This is an advisory locks on the DB which is released when the db session ends.
if err := sMan.SharedLock(c.Context); err != nil {
c.UI.Error(fmt.Errorf("Unable to gain shared access to the database: %w", err).Error())
return 1
}
defer func() {
if err := sMan.SharedUnlock(c.Context); err != nil {
c.UI.Error(fmt.Errorf("Unable to release shared lock to the database: %w", err).Error())
}
}()
ckState, err := sMan.CurrentState(c.Context)
if err != nil {
c.UI.Error(fmt.Errorf("Error checking schema state: %w", err).Error())
return 1
}
if !ckState.InitializationStarted {
c.UI.Error("Database has not been initialized. Please run `boundary database init`.")
return 1
}
if ckState.Dirty {
c.UI.Error(base.WrapAtLength("Database is in a bad state. Please revert the database into the last known good state."))
return 1
}
if ckState.BinarySchemaVersion > ckState.DatabaseSchemaVersion {
// TODO: Add the command to migrate up the schema version once that command exists.
c.UI.Error("Older schema version is than is expected from this binary.")
return 1
}
if ckState.BinarySchemaVersion < ckState.DatabaseSchemaVersion {
c.UI.Error(base.WrapAtLength(fmt.Sprintf("Newer schema version (%d) "+
"than this binary expects. Please use a newer version of the boundary "+
"binary.", ckState.DatabaseSchemaVersion)))
return 1
}
}
defer func() {

@ -1,15 +1,10 @@
package db
import (
"errors"
"fmt"
"os"
"github.com/golang-migrate/migrate/v4"
"github.com/hashicorp/boundary/internal/db/migrations"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/jinzhu/gorm"
"github.com/lib/pq"
)
@ -40,66 +35,6 @@ func Open(dbType DbType, connectionUrl string) (*gorm.DB, error) {
return db, nil
}
// Migrate a database schema
func Migrate(connectionUrl string, migrationsDirectory string) error {
if connectionUrl == "" {
return errors.New("connection url is unset")
}
if _, err := os.Stat(migrationsDirectory); os.IsNotExist(err) {
return errors.New("error migrations directory does not exist")
}
// run migrations
m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDirectory), connectionUrl)
if err != nil {
return fmt.Errorf("unable to create migrations: %w", err)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("unable to run migrations: %w", err)
}
return nil
}
// InitStore will execute the migrations needed to initialize the store. It
// returns true if migrations actually ran; false if we were already current.
func InitStore(dialect string, cleanup func() error, url string) (bool, error) {
var mErr *multierror.Error
// run migrations
source, err := migrations.NewMigrationSource(dialect)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("error creating migration driver: %w", err))
if cleanup != nil {
if err := cleanup(); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from creating driver: %w", err))
}
}
return false, mErr.ErrorOrNil()
}
m, err := migrate.NewWithSourceInstance("httpfs", source, url)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("error creating migrations: %w", err))
if cleanup != nil {
if err := cleanup(); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from creating migrations: %w", err))
}
}
return false, mErr.ErrorOrNil()
}
if err := m.Up(); err != nil {
if err == migrate.ErrNoChange {
return false, nil
}
mErr = multierror.Append(mErr, fmt.Errorf("error running migrations: %w", err))
if cleanup != nil {
if err := cleanup(); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("error cleaning up from running migrations: %w", err))
}
}
return false, mErr.ErrorOrNil()
}
return true, mErr.ErrorOrNil()
}
func GetGormLogFormatter(log hclog.Logger) func(values ...interface{}) (messages []interface{}) {
return func(values ...interface{}) (messages []interface{}) {
if len(values) > 2 && values[0].(string) == "log" {

@ -58,56 +58,3 @@ func TestOpen(t *testing.T) {
})
}
}
func TestMigrate(t *testing.T) {
cleanup, url, _, err := StartDbInDocker("postgres")
if err != nil {
t.Fatal(err)
}
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
type args struct {
connectionUrl string
migrationsDirectory string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid",
args: args{
connectionUrl: url,
migrationsDirectory: "migrations/postgres/0",
},
wantErr: false,
},
{
name: "bad-url",
args: args{
connectionUrl: "",
migrationsDirectory: "migrations/postgres/0",
},
wantErr: true,
},
{
name: "bad-dir",
args: args{
connectionUrl: url,
migrationsDirectory: "",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := Migrate(tt.args.connectionUrl, tt.args.migrationsDirectory); (err != nil) != tt.wantErr {
t.Errorf("Migrate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

@ -1,128 +0,0 @@
package migrations
import (
"bytes"
"fmt"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/httpfs"
)
// migrationDriver satisfies the remaining need of the Driver interface, since
// the package uses PartialDriver under the hood
type migrationDriver struct {
dialect string
}
// Open returns the given "file"
func (m *migrationDriver) Open(name string) (http.File, error) {
return newFakeFile(m.dialect, name)
}
// NewMigrationSource creates a source.Driver using httpfs with the given dialect
func NewMigrationSource(dialect string) (source.Driver, error) {
switch dialect {
case "postgres":
default:
return nil, fmt.Errorf("unknown migrations dialect %s", dialect)
}
return httpfs.New(&migrationDriver{dialect}, "migrations")
}
// fakeFile is used to satisfy the http.File interface
type fakeFile struct {
name string
bytes []byte
reader *bytes.Reader
dialect string
}
func newFakeFile(dialect string, name string) (*fakeFile, error) {
var ff *fakeFile
switch dialect {
case "postgres":
ff = postgresMigrations[name]
}
if ff == nil {
return nil, os.ErrNotExist
}
ff.name = strings.TrimPrefix(name, "migrations/")
ff.reader = bytes.NewReader(ff.bytes)
ff.dialect = dialect
return ff, nil
}
func (f *fakeFile) Read(p []byte) (n int, err error) {
return f.reader.Read(p)
}
func (f *fakeFile) Seek(offset int64, whence int) (int64, error) {
return f.reader.Seek(offset, whence)
}
func (f *fakeFile) Close() error { return nil }
// Readdir returns os.FileInfo values, in sorted order, and eliding the
// migrations "dir"
func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
// Get the right map
var migrationsMap map[string]*fakeFile
switch f.dialect {
case "postgres":
migrationsMap = postgresMigrations
default:
return nil, fmt.Errorf("unknown database dialect %s", f.dialect)
}
// Sort the keys. May not be necessary but feels nice.
keys := make([]string, 0, len(migrationsMap))
for k := range migrationsMap {
keys = append(keys, k)
}
sort.Strings(keys)
// Create the slice of fileinfo objects to return
ret := make([]os.FileInfo, 0, len(migrationsMap))
for i, v := range keys {
// We need "migrations" in the map for the initial Open call but we
// should not return it as part of the "directory"'s "files".
if v == "migrations" {
continue
}
stat, err := migrationsMap[v].Stat()
if err != nil {
return nil, err
}
ret = append(ret, stat)
if count > 0 && count == i {
break
}
}
return ret, nil
}
// Stat returns a new fakeFileInfo object with the necessary bits
func (f *fakeFile) Stat() (os.FileInfo, error) {
return &fakeFileInfo{
name: f.name,
size: int64(len(f.bytes)),
}, nil
}
// fakeFileInfo satisfies os.FileInfo but represents our fake "files"
type fakeFileInfo struct {
name string
size int64
}
func (f *fakeFileInfo) Name() string { return f.name }
func (f *fakeFileInfo) Size() int64 { return f.size }
func (f *fakeFileInfo) Mode() os.FileMode { return os.ModePerm }
func (f *fakeFileInfo) ModTime() time.Time { return time.Now() }
func (f *fakeFileInfo) IsDir() bool { return false }
func (f *fakeFileInfo) Sys() interface{} { return nil }

@ -1,177 +0,0 @@
package migrations
import (
"os"
"reflect"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/stretchr/testify/assert"
)
func TestNewMigrationSource(t *testing.T) {
type args struct {
dialect string
}
tests := []struct {
name string
args args
want source.Driver
wantErr bool
}{
{
name: "postgres",
args: args{dialect: "postgres"},
want: func() source.Driver {
d, err := httpfs.New(&migrationDriver{"postgres"}, "migrations")
if err != nil {
t.Errorf("NewMigrationSource() error creating httpfs = %w", err)
}
return d
}(),
wantErr: false,
},
{
name: "no-dialect",
args: args{dialect: ""},
want: nil,
wantErr: true,
},
{
name: "bad-dialect",
args: args{dialect: "rainbows-and-unicorns-db"},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewMigrationSource(tt.args.dialect)
if (err != nil) != tt.wantErr {
t.Errorf("NewMigrationSource() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewMigrationSource() = %v, want %v", got, tt.want)
}
})
}
}
func Test_migrationDriver_Open(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
dialect string
args args
wantErr bool
}{
{
name: "valid-file",
dialect: "postgres",
args: args{name: "migrations/01_domain_types.up.sql"},
wantErr: false,
},
{
name: "bad-file",
dialect: "postgres",
args: args{name: "migrations/unicorns-and-rainbows.up.sql"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &migrationDriver{
dialect: tt.dialect,
}
_, err := m.Open(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("migrationDriver.Open() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_fakeFile_Read(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
assert.NoError(err)
buf := make([]byte, len(ff.bytes))
n, err := ff.Read(buf)
assert.NoError(err)
assert.Equal(len(buf), n)
})
}
func Test_fakeFile_Seek(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
assert.NoError(err)
buf := make([]byte, len(ff.bytes))
n, err := ff.Seek(10, 0)
assert.NoError(err)
assert.Equal(int64(10), n)
n2, err := ff.Read(buf)
assert.NoError(err)
assert.Equal(len(ff.bytes)-10, n2)
})
}
func Test_fakeFile_Close(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
m := &migrationDriver{
dialect: "postgres",
}
f, err := m.Open("migrations/01_domain_types.up.sql")
assert.NoError(err)
err = f.Close()
assert.NoError(err)
})
}
func Test_fakeFile_Stat(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
name := "migrations/01_domain_types.up.sql"
ff, err := newFakeFile("postgres", name)
assert.NoError(err)
info, err := ff.Stat()
assert.NoError(err)
assert.Equal(ff.name, info.Name())
assert.Equal(int64(len(ff.bytes)), info.Size())
assert.Equal(os.ModePerm, info.Mode())
assert.Equal(false, info.IsDir())
assert.Equal(nil, info.Sys())
})
}
func Test_fakeFile_Readdir(t *testing.T) {
assert := assert.New(t)
t.Run("valid", func(t *testing.T) {
name := "migrations/01_domain_types.up.sql"
ff, err := newFakeFile("postgres", name)
assert.NoError(err)
info, err := ff.Readdir(0)
assert.NoError(err)
assert.NotNil(info)
info, err = ff.Readdir(1)
assert.NoError(err)
assert.NotNil(info)
assert.Equal(1, len(info))
info, err = ff.Readdir(0)
assert.NoError(err)
assert.NotNil(info)
// we don't want to count "migrations", so we're len - 1
assert.Equal(len(postgresMigrations)-1, len(info))
})
}

@ -1,149 +0,0 @@
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"sort"
"strconv"
"strings"
"text/template"
)
// generate looks for migration sql in a directory for the given dialect and
// applies the templates below to the contents of the files, building up a
// migrations map for the dialect
func generate(dialect string) {
baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/migrations"
dir, err := os.Open(fmt.Sprintf("%s/%s", baseDir, dialect))
if err != nil {
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
versions, err := dir.Readdirnames(0)
if err != nil {
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
outBuf := bytes.NewBuffer(nil)
valuesBuf := bytes.NewBuffer(nil)
sort.Strings(versions)
isDev := false
largestVer := 0
for _, ver := range versions {
var verVal int
switch ver {
case "dev":
verVal = largestVer + 1
default:
if verVal, err = strconv.Atoi(ver); err != nil {
fmt.Printf("error reading major schema version directory %q. Must be a number or 'dev'\n", ver)
os.Exit(1)
}
if verVal > largestVer {
largestVer = verVal
}
}
dir, err := os.Open(fmt.Sprintf("%s/%s/%s", baseDir, dialect, ver))
if err != nil {
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
names, err := dir.Readdirnames(0)
if err != nil {
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
if ver == "dev" && len(names) > 0 {
isDev = true
}
sort.Strings(names)
for _, name := range names {
if !strings.HasSuffix(name, ".sql") {
continue
}
contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", baseDir, dialect, ver, name))
if err != nil {
fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err)
os.Exit(1)
}
vName := name
nameParts := strings.SplitN(name, "_", 2)
if len(nameParts) != 2 {
continue
}
nameVer, err := strconv.Atoi(nameParts[0])
if err != nil {
fmt.Printf("Unable to get file version from %q\n", name)
continue
}
vName = fmt.Sprintf("%02d_%s", (verVal*1000)+nameVer, nameParts[1])
if err := migrationsValueTemplate.Execute(valuesBuf, struct {
Name string
Contents string
}{
Name: vName,
Contents: string(contents),
}); err != nil {
fmt.Printf("error executing migrations value template for file %s/%s: %s", ver, name, err)
os.Exit(1)
}
}
}
if err := migrationsTemplate.Execute(outBuf, struct {
Type string
Values string
DevMigration bool
}{
Type: dialect,
Values: valuesBuf.String(),
DevMigration: isDev,
}); err != nil {
fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err)
os.Exit(1)
}
outFile := fmt.Sprintf("%s/%s.gen.go", baseDir, dialect)
if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil {
fmt.Printf("error writing file %q: %v\n", outFile, err)
os.Exit(1)
}
}
var migrationsTemplate = template.Must(template.New("").Parse(
`// Code generated by "make migrations"; DO NOT EDIT.
package migrations
import (
"bytes"
)
// DevMigration is true if the database schema that would be applied by
// InitStore would be from files in the /dev directory which indicates it would
// not be safe to run in a non dev environment.
var DevMigration = {{ .DevMigration }}
var {{ .Type }}Migrations = map[string]*fakeFile{
"migrations": {
name: "migrations",
},
{{ .Values }}
}
`))
var migrationsValueTemplate = template.Must(template.New("").Parse(
`"migrations/{{ .Name }}": {
name: "{{ .Name }}",
bytes: []byte(` + "`\n{{ .Contents }}\n`" + `),
},
`))

@ -0,0 +1,187 @@
package schema
import (
"bytes"
"context"
"database/sql"
"fmt"
"io"
"github.com/hashicorp/boundary/internal/db/schema/postgres"
"github.com/hashicorp/boundary/internal/errors"
)
// driver provides functionality to a database.
type driver interface {
TrySharedLock(context.Context) error
TryLock(context.Context) error
Lock(context.Context) error
Unlock(context.Context) error
UnlockShared(context.Context) error
Run(context.Context, io.Reader) error
// A value of -1 indicates no version is set.
SetVersion(context.Context, int, bool) error
// A value of -1 indicates no version is set.
Version(context.Context) (int, bool, error)
}
// Manager provides a way to run operations and retrieve information regarding
// the underlying boundary database schema.
// Manager is not thread safe.
type Manager struct {
db *sql.DB
driver driver
dialect string
}
// NewManager creates a new schema manager. An error is returned
// if the provided dialect is unrecognized or if the passed in db is unreachable.
func NewManager(ctx context.Context, dialect string, db *sql.DB) (*Manager, error) {
const op = "schema.NewManager"
dbM := Manager{db: db, dialect: dialect}
switch dialect {
case "postgres":
var err error
dbM.driver, err = postgres.New(ctx, db)
if err != nil {
return nil, errors.Wrap(err, op)
}
default:
return nil, errors.New(errors.InvalidParameter, op, fmt.Sprintf("unknown dialect %q", dialect))
}
return &dbM, nil
}
// State contains information regarding the current state of a boundary database's schema.
type State struct {
// InitializationStarted indicates if the current database has already been initialized
// (successfully or not) at least once.
InitializationStarted bool
// Dirty is set to true if the database failed in a previous migration/initialization.
Dirty bool
// DatabaseSchemaVersion is the schema version that is currently running in the database.
DatabaseSchemaVersion int
// BinarySchemaVersion is the schema version which this boundary binary supports.
BinarySchemaVersion int
}
// CurrentState provides the state of the boundary schema contained in the backing database.
func (b *Manager) CurrentState(ctx context.Context) (*State, error) {
dbS := State{
BinarySchemaVersion: BinarySchemaVersion(b.dialect),
}
v, dirty, err := b.driver.Version(ctx)
if err != nil {
return nil, err
}
if v == nilVersion {
return &dbS, nil
}
dbS.InitializationStarted = true
dbS.DatabaseSchemaVersion = v
dbS.Dirty = dirty
return &dbS, nil
}
// SharedLock attempts to obtain a shared lock on the database. This can fail if
// an exclusive lock is already held with the same key. An error is returned if
// a lock was unable to be obtained.
func (b *Manager) SharedLock(ctx context.Context) error {
const op = "schema.(Manager).SharedLock"
if err := b.driver.TrySharedLock(ctx); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// SharedUnlock releases a shared lock on the database. If this
// fails for whatever reason an error is returned. Unlocking a lock
// that is not held is not an error.
func (b *Manager) SharedUnlock(ctx context.Context) error {
const op = "schema.(Manager).SharedUnlock"
if err := b.driver.UnlockShared(ctx); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// ExclusiveLock attempts to obtain an exclusive lock on the database.
// An error is returned if a lock was unable to be obtained.
func (b *Manager) ExclusiveLock(ctx context.Context) error {
const op = "schema.(Manager).ExclusiveLock"
if err := b.driver.TryLock(ctx); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// ExclusiveUnlock releases a shared lock on the database. If this
// fails for whatever reason an error is returned. Unlocking a lock
// that is not held is not an error.
func (b *Manager) ExclusiveUnlock(ctx context.Context) error {
const op = "schema.(Manager).ExclusiveUnlock"
if err := b.driver.Unlock(ctx); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// RollForward updates the database schema to match the latest version known by
// the boundary binary. An error is not returned if the database is already at
// the most recent version.
func (b *Manager) RollForward(ctx context.Context) error {
const op = "schema.(Manager).RollForward"
// Capturing a lock that this session to the db already possesses is okay.
if err := b.driver.Lock(ctx); err != nil {
return errors.Wrap(err, op)
}
defer func() {
b.driver.Unlock(ctx)
}()
curVersion, dirty, err := b.driver.Version(ctx)
if err != nil {
return errors.Wrap(err, op)
}
if dirty {
return errors.New(errors.NotSpecificIntegrity, op, fmt.Sprintf("schema is dirty with version %d", curVersion))
}
sp, err := newStatementProvider(b.dialect, curVersion)
if err != nil {
return errors.Wrap(err, op)
}
return b.runMigrations(ctx, sp)
}
// runMigrations passes migration queries to a database driver and manages
// the version and dirty bit. Cancelation or deadline/timeout is managed
// through the passed in context.
func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) error {
const op = "schema.(Manager).runMigrations"
for qp.Next() {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), op)
default:
// context is not done yet. Continue on to the next query to execute.
}
// set version with dirty state
if err := b.driver.SetVersion(ctx, qp.Version(), true); err != nil {
return errors.Wrap(err, op)
}
if err := b.driver.Run(ctx, bytes.NewReader(qp.ReadUp())); err != nil {
return errors.Wrap(err, op)
}
// set clean state
if err := b.driver.SetVersion(ctx, qp.Version(), false); err != nil {
return errors.Wrap(err, op)
}
}
return nil
}

@ -0,0 +1,201 @@
package schema
import (
"context"
"database/sql"
"testing"
"github.com/hashicorp/boundary/internal/db/schema/postgres"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewManager(t *testing.T) {
c, u, _, err := docker.StartDbInDocker("postgres")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := sql.Open("postgres", u)
require.NoError(t, err)
ctx := context.Background()
_, err = NewManager(ctx, "postgres", d)
require.NoError(t, err)
_, err = NewManager(ctx, "unknown", d)
assert.True(t, errors.Match(errors.T(errors.InvalidParameter), err))
d.Close()
_, err = NewManager(ctx, "postgres", d)
assert.True(t, errors.Match(errors.T(errors.Op("schema.NewManager")), err))
}
func TestCurrentState(t *testing.T) {
c, u, _, err := docker.StartDbInDocker("postgres")
t.Cleanup(func() {
if err := c(); err != nil {
t.Fatalf("Got error at cleanup: %v", err)
}
})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
ctx := context.Background()
d, err := sql.Open("postgres", u)
require.NoError(t, err)
m, err := NewManager(ctx, "postgres", d)
require.NoError(t, err)
want := &State{
BinarySchemaVersion: BinarySchemaVersion("postgres"),
}
s, err := m.CurrentState(ctx)
require.NoError(t, err)
assert.Equal(t, want, s)
testDriver, err := postgres.New(ctx, d)
require.NoError(t, err)
require.NoError(t, testDriver.SetVersion(ctx, 2, true))
want = &State{
InitializationStarted: true,
BinarySchemaVersion: BinarySchemaVersion("postgres"),
Dirty: true,
DatabaseSchemaVersion: 2,
}
s, err = m.CurrentState(ctx)
require.NoError(t, err)
assert.Equal(t, want, s)
}
func TestRollForward(t *testing.T) {
c, u, _, err := docker.StartDbInDocker("postgres")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := sql.Open("postgres", u)
require.NoError(t, err)
ctx := context.Background()
m, err := NewManager(ctx, "postgres", d)
require.NoError(t, err)
assert.NoError(t, m.RollForward(ctx))
// Now set to dirty at an early version
testDriver, err := postgres.New(ctx, d)
require.NoError(t, err)
testDriver.SetVersion(ctx, 0, true)
assert.Error(t, m.RollForward(ctx))
}
func TestRollForward_NotFromFresh(t *testing.T) {
dialect := "postgres"
oState := migrationStates[dialect]
nState := createPartialMigrationState(oState, 8)
migrationStates[dialect] = nState
c, u, _, err := docker.StartDbInDocker("postgres")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d, err := sql.Open(dialect, u)
require.NoError(t, err)
// Initialize the DB with only a portion of the current sql scripts.
ctx := context.Background()
m, err := NewManager(ctx, dialect, d)
require.NoError(t, err)
assert.NoError(t, m.RollForward(ctx))
ver, dirty, err := m.driver.Version(ctx)
assert.NoError(t, err)
assert.Equal(t, nState.binarySchemaVersion, ver)
assert.False(t, dirty)
// Restore the full set of sql scripts and roll the rest of the way forward.
migrationStates[dialect] = oState
newM, err := NewManager(ctx, dialect, d)
require.NoError(t, err)
assert.NoError(t, newM.RollForward(ctx))
ver, dirty, err = newM.driver.Version(ctx)
assert.NoError(t, err)
assert.Equal(t, oState.binarySchemaVersion, ver)
assert.False(t, dirty)
}
func TestManager_ExclusiveLock(t *testing.T) {
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker("postgres")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d1, err := sql.Open("postgres", u)
require.NoError(t, err)
m1, err := NewManager(ctx, "postgres", d1)
require.NoError(t, err)
d2, err := sql.Open("postgres", u)
require.NoError(t, err)
m2, err := NewManager(ctx, "postgres", d2)
require.NoError(t, err)
assert.NoError(t, m1.ExclusiveLock(ctx))
assert.NoError(t, m1.ExclusiveLock(ctx))
assert.Error(t, m2.ExclusiveLock(ctx))
assert.Error(t, m2.SharedLock(ctx))
}
func TestManager_SharedLock(t *testing.T) {
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker("postgres")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
d1, err := sql.Open("postgres", u)
require.NoError(t, err)
m1, err := NewManager(ctx, "postgres", d1)
require.NoError(t, err)
d2, err := sql.Open("postgres", u)
require.NoError(t, err)
m2, err := NewManager(ctx, "postgres", d2)
require.NoError(t, err)
assert.NoError(t, m1.SharedLock(ctx))
assert.NoError(t, m2.SharedLock(ctx))
assert.NoError(t, m1.SharedLock(ctx))
assert.NoError(t, m2.SharedLock(ctx))
assert.Error(t, m1.ExclusiveLock(ctx))
assert.Error(t, m2.ExclusiveLock(ctx))
}
// Creates a new migrationState only with the versions <= the provided maxVer
func createPartialMigrationState(om migrationState, maxVer int) migrationState {
nState := migrationState{
devMigration: om.devMigration,
upMigrations: make(map[int][]byte),
downMigrations: make(map[int][]byte),
}
for k := range om.upMigrations {
if k > maxVer {
// Don't store any versions past our test version.
continue
}
nState.upMigrations[k] = om.upMigrations[k]
nState.downMigrations[k] = om.downMigrations[k]
if nState.binarySchemaVersion < k {
nState.binarySchemaVersion = k
}
}
return nState
}

@ -0,0 +1,27 @@
# migrations package
This package handles the generation of the database schema in a format that can
be compiled into the boundary binary.
## Organization
* `./generate`: contains the makefile, code, and templates needed to generate the schema info.
* `./postgres`: contains the versioned schema folders. The contents of these folders, except
for `dev` should not be modified.
## Usage
To regenerate the schema information into the format the boundary binary uses run
`make migrations` or `make gen` to recreate all generated code.
The content of the folders under `./postgres` are compiled into the
boundary binary and when the `boundary database init` or `boundary database migrate`
commands are executed they are applied in order of their version.
The `./postgres/dev` directory contains schema files that are under development and
are not included in a release yet and so it is the only directory where additions and
modifications are allowed. When a boundary binary is built when this directory is not
empty a special flag is required to run the `boundary database init` command to indicate
the user is aware that this is a development release and running this command can
result in a completely broken schema and dataloss.
When a new release is made the contents of the `dev` directory are moved into a new
versioned directory.

@ -3,6 +3,6 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST))
migrations:
go run .
goimports -w ${GEN_BASEPATH}/internal/db/migrations
goimports -w ${GEN_BASEPATH}/internal/db/schema
.PHONY: migrations

@ -0,0 +1,158 @@
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"sort"
"strconv"
"strings"
"text/template"
)
// generate looks for migration sql in a directory for the given dialect and
// applies the templates below to the contents of the files, building up a
// migrations map for the dialect
func generate(dialect string) {
baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/schema"
srcDir := baseDir + "/migrations"
dir, err := os.Open(fmt.Sprintf("%s/%s", srcDir, dialect))
if err != nil {
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
versions, err := dir.Readdirnames(0)
if err != nil {
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
sort.Strings(versions)
type ContentValues struct {
Name string
Content string
}
var upContents []ContentValues
var downContents []ContentValues
isDev := false
var lRelVer, largestSchemaVersion int
for _, ver := range versions {
var verVal int
switch ver {
case "dev":
verVal = lRelVer + 1
default:
v, err := strconv.Atoi(ver)
if err != nil {
fmt.Printf("error reading major schema version directory %q. Must be a number or 'dev'\n", ver)
os.Exit(1)
}
verVal = v
if verVal > lRelVer {
lRelVer = verVal
}
}
dir, err := os.Open(fmt.Sprintf("%s/%s/%s", srcDir, dialect, ver))
if err != nil {
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
names, err := dir.Readdirnames(0)
if err != nil {
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
if ver == "dev" && len(names) > 0 {
isDev = true
}
sort.Strings(names)
for _, name := range names {
if !strings.HasSuffix(name, ".sql") {
continue
}
contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", srcDir, dialect, ver, name))
if err != nil {
fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err)
os.Exit(1)
}
nameParts := strings.SplitN(name, "_", 2)
if len(nameParts) != 2 {
continue
}
v, err := strconv.Atoi(nameParts[0])
if err != nil {
fmt.Printf("Unable to get file version from %q\n", name)
continue
}
fullV := (verVal * 1000) + v
if fullV > largestSchemaVersion {
largestSchemaVersion = fullV
}
cv := ContentValues{
Name: fmt.Sprint(fullV),
Content: string(contents),
}
switch {
case strings.Contains(nameParts[1], ".down."):
downContents = append(downContents, cv)
case strings.Contains(nameParts[1], ".up."):
upContents = append(upContents, cv)
}
}
}
// fmt.Printf("Got upcontent: %#v\n\n downcontents: %#v", upContents[:1], downContents[:1])
outBuf := bytes.NewBuffer(nil)
if err := migrationsTemplate.Execute(outBuf, struct {
Type string
UpValues []ContentValues
DownValues []ContentValues
DevMigration bool
BinarySchemaVersion int
}{
Type: dialect,
UpValues: upContents,
DownValues: downContents,
DevMigration: isDev,
BinarySchemaVersion: largestSchemaVersion,
}); err != nil {
fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err)
os.Exit(1)
}
outFile := fmt.Sprintf("%s/%s_migration.gen.go", baseDir, dialect)
if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil {
fmt.Printf("error writing file %q: %v\n", outFile, err)
os.Exit(1)
}
}
var migrationsTemplate = template.Must(template.Must(template.New("Content").Parse(
`{{ .Name }}: []byte(` + "`\n{{ .Content }}\n`" + `),
`)).New("MainPage").Parse(`package schema
// Code generated by "make migrations"; DO NOT EDIT.
func init() {
migrationStates["{{ .Type }}"] = migrationState{
devMigration: {{ .DevMigration }},
binarySchemaVersion: {{ .BinarySchemaVersion }},
upMigrations: map[int][]byte{
{{range .UpValues }}{{ template "Content" . }}{{end}}
},
downMigrations: map[int][]byte{
{{range .DownValues }}{{ template "Content" . }}{{end}}
},
}
}
`))

@ -0,0 +1,356 @@
// The MIT License (MIT)
//
// Original Work
// Copyright (c) 2016 Matthias Kadenbach
// https://github.com/mattes/migrate
//
// Modified Work
// Copyright (c) 2018 Dale Hui
// https://github.com/golang-migrate/migrate
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package postgres
import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
"strconv"
"strings"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
)
// schemaAccessLockId is a Lock key used to ensure a single boundary binary is operating
// on a postgres server at a time. The value has no meaning and was picked randomly.
const (
schemaAccessLockId int64 = 3865661975
nilVersion = -1
)
var defaultMigrationsTable = "boundary_schema_version"
// Postgres is a driver usable by a boundary schema manager.
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
}
// New returns a postgres pointer with the provided db verified as
// connectable and a version table being initialized.
func New(ctx context.Context, instance *sql.DB) (*Postgres, error) {
const op = "postgres.New"
if err := instance.PingContext(ctx); err != nil {
return nil, errors.Wrap(err, op)
}
conn, err := instance.Conn(ctx)
if err != nil {
return nil, errors.Wrap(err, op)
}
px := &Postgres{
conn: conn,
db: instance,
}
if err := px.ensureVersionTable(ctx); err != nil {
return nil, errors.Wrap(err, op)
}
return px, nil
}
// TrySharedLock attempts to capture a shared lock. If it is not successful it returns an error.
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TrySharedLock(ctx context.Context) error {
const op = "postgres.(Postgres).TrySharedLock"
r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock_shared($1)", schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(r.Err(), op)
}
var gotLock bool
if err := r.Scan(&gotLock); err != nil {
return errors.Wrap(err, op)
}
if !gotLock {
return errors.New(errors.MigrationLock, op, "Lock failed")
}
return nil
}
// TryLock attempts to capture an exclusive lock. If it is not successful it returns an error.
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) TryLock(ctx context.Context) error {
const op = "postgres.(Postgres).TryLock"
r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", schemaAccessLockId)
if r.Err() != nil {
return errors.Wrap(r.Err(), op)
}
var gotLock bool
if err := r.Scan(&gotLock); err != nil {
return errors.Wrap(err, op)
}
if !gotLock {
return errors.New(errors.MigrationLock, op, "Lock failed")
}
return nil
}
// Lock calls pg_advisory_lock with the provided context and returns an error
// if we were unable to get the lock before the context cancels.
func (p *Postgres) Lock(ctx context.Context) error {
const op = "postgres.(Postgres).Lock"
// This will wait indefinitely until the Lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// Unlock calls pg_advisory_unlock and returns an error if we were unable to
// release the lock before the context cancels.
func (p *Postgres) Unlock(ctx context.Context) error {
const op = "postgres.(Postgres).Unlock"
query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// UnlockShared calls pg_advisory_unlock_shared and returns an error if we were unable to
// release the lock before the context cancels.
func (p *Postgres) UnlockShared(ctx context.Context) error {
const op = "postgres.(Postgres).UnlockShared"
query := `SELECT pg_advisory_unlock_shared($1)`
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// Executes the sql provided in the passed in io.Reader. The contents of the reader must
// fit in memory as the full content is read into a string before being passed to the
// backing database.
func (p *Postgres) Run(ctx context.Context, migration io.Reader) error {
const op = "postgres.(Postgres).Run"
migr, err := ioutil.ReadAll(migration)
if err != nil {
return errors.Wrap(err, op)
}
// Run migration
query := string(migr)
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed")
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
message = fmt.Sprintf("%s, on line %v: %s", message, line, migr)
return errors.Wrap(err, op, errors.WithMsg(message))
}
return errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("migration failed: %s", migr)))
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
// SetVersion sets the version number, and whether the database is in a dirty state.
// A version value of -1 indicates no version is set.
func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error {
const op = "postgres.(Postgres).SetVersion"
tx, err := p.conn.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
}
query := `TRUNCATE ` + pq.QuoteIdentifier(defaultMigrationsTable)
if _, err := tx.ExecContext(ctx, query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return errors.Wrap(err, op)
}
// Also re-write the schema Version for nil dirty versions to prevent
// empty schema Version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == nilVersion && dirty) {
query = `INSERT INTO ` + pq.QuoteIdentifier(defaultMigrationsTable) +
` (Version, dirty) VALUES ($1, $2)`
if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return errors.Wrap(err, op)
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(err, op)
}
return nil
}
// Version returns the version, if the database is currently in a dirty state, and any error.
// A version value of -1 indicates no version is set.
func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) {
const op = "postgres.(Postgres).Version"
query := `SELECT Version, dirty FROM ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` LIMIT 1`
err = p.conn.QueryRowContext(ctx, query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return nilVersion, false, nil
case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return nilVersion, false, nil
}
}
return 0, false, errors.Wrap(err, op)
default:
return version, dirty, nil
}
}
func (p *Postgres) drop(ctx context.Context) (err error) {
const op = "postgres.(Postgres).drop"
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(ctx, query)
if err != nil {
return errors.Wrap(err, op)
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
err = errors.Wrap(err, op)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return errors.Wrap(err, op)
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return errors.Wrap(err, op)
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(ctx, query); err != nil {
return errors.Wrap(err, op)
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the postgres type.
func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) {
const op = "postgres.(Postgres).ensureVersionTable"
if err = p.Lock(ctx); err != nil {
return errors.Wrap(err, op)
}
defer func() {
if e := p.Unlock(ctx); e != nil {
err = multierror.Append(err, e)
}
}()
query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (Version bigint primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(ctx, query); err != nil {
return errors.Wrap(err, op)
}
return nil
}

@ -0,0 +1,433 @@
// The MIT License (MIT)
//
// Original Work
// Copyright (c) 2016 Matthias Kadenbach
// https://github.com/mattes/migrate
//
// Modified Work
// Copyright (c) 2018 Dale Hui
// https://github.com/golang-migrate/migrate
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package postgres
// error codes https://github.com/lib/pq/blob/master/error.go
import (
"bytes"
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4/dktesting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady,
}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:12", Options: opts},
}
)
func pgConnectionString(host, port string) string {
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("postgres", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func TestDbStuff(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
test(t, d, []byte("SELECT 1"))
})
}
func TestVersion_NoVersionTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
// Drop the version table so calls to Version don't rely on that
d.drop(ctx)
v, dirt, err := d.Version(ctx)
assert.NoError(t, err)
assert.Equal(t, v, nilVersion)
assert.False(t, dirt)
})
}
func TestMultiStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Error(err)
}
}()
if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
d, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.close(t); err != nil {
t.Fatal(err)
}
}()
// create foobar schema
if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.SetVersion(ctx, 1, false); err != nil {
t.Fatal(err)
}
// re-connect using that schema
d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d2.close(t); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.Version(ctx)
if err != nil {
t.Fatal(err)
}
if version != nilVersion {
t.Fatal("expected NilVersion")
}
// now update Version and compare
if err := d2.SetVersion(ctx, 2, false); err != nil {
t.Fatal(err)
}
version, _, err = d2.Version(ctx)
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected Version 2")
}
// meanwhile, the public schema still has the other Version
version, _, err = d.Version(ctx)
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected Version 2")
}
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
ps, err := open(t, ctx, addr)
if err != nil {
t.Fatal(err)
}
test(t, ps, []byte("SELECT 1"))
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
err = ps.Lock(ctx)
if err != nil {
t.Fatal(err)
}
err = ps.Unlock(ctx)
if err != nil {
t.Fatal(err)
}
})
}
func TestWithInstance_Concurrent(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// The number of concurrent processes running New
const concurrency = 30
// We can instantiate a single database handle because it is
// actually a connection pool, and so, each of the below go
// routines will have a high probability of using a separate
// connection, which is something we want to exercise.
db, err := sql.Open("postgres", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
db.SetMaxIdleConns(concurrency)
db.SetMaxOpenConns(concurrency)
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func(i int) {
defer wg.Done()
_, err := New(context.Background(), db)
if err != nil {
t.Errorf("process %d error: %s", i, err)
}
}(i)
}
})
}
func TestRun_Error(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p, err := open(t, ctx, addr)
if err != nil {
require.NoError(t, err)
}
t.Cleanup(func() {
require.NoError(t, p.close(t))
})
err = p.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")))
assert.Error(t, err)
})
}
func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
wantLine uint
wantCol uint
input string
wantOk bool
}{
{
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
},
{
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
},
{
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
},
{
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
},
{
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
},
{
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
},
{
17, 2, 8, "SELECT *\nFROM foo", true, // last character
},
{
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
},
}
for i, tc := range testcases {
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
run := func(crlf bool, nonASCII bool) {
var name string
if crlf {
name = "crlf"
} else {
name = "lf"
}
if nonASCII {
name += "-nonascii"
} else {
name += "-ascii"
}
t.Run(name, func(t *testing.T) {
input := tc.input
if crlf {
input = strings.Replace(input, "\n", "\r\n", -1)
}
if nonASCII {
input = strings.Replace(input, "FROM", "FRÖM", -1)
}
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
if tc.wantOk {
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
}
if gotOK != tc.wantOk {
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
}
if gotLine != tc.wantLine {
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
}
if gotCol != tc.wantCol {
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
}
})
}
run(false, false)
run(true, false)
run(false, true)
run(true, true)
})
}
}

@ -0,0 +1,169 @@
// The MIT License (MIT)
//
// Original Work
// Copyright (c) 2016 Matthias Kadenbach
// https://github.com/mattes/migrate
//
// Modified Work
// Copyright (c) 2018 Dale Hui
// https://github.com/golang-migrate/migrate
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package postgres
import (
"bytes"
"context"
"database/sql"
"io"
"testing"
"time"
"github.com/golang-migrate/migrate/v4/database"
"github.com/stretchr/testify/require"
)
// test runs tests against database implementations.
func test(t *testing.T, d *Postgres, migration []byte) {
if migration == nil {
t.Fatal("test must provide migration reader")
}
testNilVersion(t, d) // test first
testLockAndUnlock(t, d)
testRun(t, d, bytes.NewReader(migration))
testSetVersion(t, d) // also tests Version()
// drop breaks the driver, so test it last.
testDrop(t, d)
}
func testNilVersion(t *testing.T, d *Postgres) {
ctx := context.Background()
v, _, err := d.Version(ctx)
if err != nil {
t.Fatal(err)
}
if v != database.NilVersion {
t.Fatalf("Version: expected Version to be NilVersion (-1), got %v", v)
}
}
func testLockAndUnlock(t *testing.T, d *Postgres) {
ctx := context.Background()
ctx, _ = context.WithTimeout(ctx, 15*time.Second)
// locking twice is ok, no error
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
// Unlock
if err := d.Unlock(ctx); err != nil {
t.Fatalf("error unlocking: %v", err)
}
// try to Lock
if err := d.Lock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
if err := d.Unlock(ctx); err != nil {
t.Fatalf("got error, expected none: %v", err)
}
}
func testRun(t *testing.T, d *Postgres, migration io.Reader) {
ctx := context.Background()
if migration == nil {
t.Fatal("migration can't be nil")
}
if err := d.Run(ctx, migration); err != nil {
t.Fatal(err)
}
}
func testDrop(t *testing.T, d *Postgres) {
ctx := context.Background()
if err := d.drop(ctx); err != nil {
t.Fatal(err)
}
}
func testSetVersion(t *testing.T, d *Postgres) {
ctx := context.Background()
// nolint:maligned
testCases := []struct {
name string
version int
dirty bool
expectedErr error
expectedReadErr error
expectedVersion int
expectedDirty bool
}{
{name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
{name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
{name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
{name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
{name: "last migration dirty", version: database.NilVersion, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true},
{name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := d.SetVersion(ctx, tc.version, tc.dirty)
if err != tc.expectedErr {
t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr)
}
v, dirty, readErr := d.Version(ctx)
if readErr != tc.expectedReadErr {
t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr)
}
if v != tc.expectedVersion {
t.Error("Got unexpected Version:", v, "!=", tc.expectedVersion)
}
if dirty != tc.expectedDirty {
t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty)
}
})
}
}
func open(t *testing.T, ctx context.Context, u string) (*Postgres, error) {
t.Helper()
db, err := sql.Open("postgres", u)
require.NoError(t, err)
px, err := New(ctx, db)
require.NoError(t, err)
return px, nil
}
func (p *Postgres) close(t *testing.T) error {
t.Helper()
require.NoError(t, p.conn.Close())
require.NoError(t, p.db.Close())
return nil
}

@ -0,0 +1,40 @@
package schema
import (
"context"
"database/sql"
"github.com/hashicorp/boundary/internal/errors"
)
// InitStore executes the migrations needed to initialize the store. It
// returns true if migrations actually ran; false if the database is already current.
func InitStore(ctx context.Context, dialect string, url string) (bool, error) {
const op = "schema.InitStore"
d, err := sql.Open(dialect, url)
if err != nil {
return false, errors.Wrap(err, op)
}
sMan, err := NewManager(ctx, dialect, d)
if err != nil {
return false, errors.Wrap(err, op)
}
st, err := sMan.CurrentState(ctx)
if err != nil {
return false, errors.Wrap(err, op)
}
if st.Dirty {
return false, errors.New(errors.MigrationIntegrity, op, "db marked dirty")
}
if st.InitializationStarted && st.DatabaseSchemaVersion == st.BinarySchemaVersion {
return false, nil
}
if err := sMan.RollForward(ctx); err != nil {
return false, errors.Wrap(err, op)
}
return true, nil
}

@ -0,0 +1,65 @@
package schema
import (
"context"
"database/sql"
"testing"
"github.com/hashicorp/boundary/internal/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitStore(t *testing.T) {
dialect := "postgres"
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
// Set the possible migration state to be only part of the full migration
oState := migrationStates[dialect]
nState := createPartialMigrationState(oState, 8)
migrationStates[dialect] = nState
ran, err := InitStore(ctx, dialect, u)
assert.NoError(t, err)
assert.True(t, ran)
ran, err = InitStore(ctx, dialect, u)
assert.NoError(t, err)
assert.False(t, ran)
// Reset the possible migration state to contain everything
migrationStates[dialect] = oState
ran, err = InitStore(ctx, dialect, u)
assert.NoError(t, err)
assert.True(t, ran)
ran, err = InitStore(ctx, dialect, u)
assert.NoError(t, err)
assert.False(t, ran)
}
func TestInitStore_Dirty(t *testing.T) {
dialect := "postgres"
ctx := context.Background()
c, u, _, err := docker.StartDbInDocker(dialect)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, c())
})
// Mark the db as dirty indicating a previously run failed migration
db, err := sql.Open(dialect, u)
require.NoError(t, err)
m, err := NewManager(ctx, dialect, db)
m.driver.SetVersion(ctx, -1, true)
b, err := InitStore(ctx, dialect, u)
assert.Error(t, err)
assert.False(t, b)
}

@ -0,0 +1,54 @@
package schema
const nilVersion = -1
// migrationState is meant to be populated by the generated migration code and
// contains the internal representation of a schema in the current binary.
type migrationState struct {
// devMigration is true if the database schema that would be applied by
// InitStore would be from files in the /dev directory which indicates it would
// not be safe to run in a non dev environment.
devMigration bool
// binarySchemaVersion provides the database schema version supported by
// this binary.
binarySchemaVersion int
upMigrations map[int][]byte
downMigrations map[int][]byte
}
// migrationStates is populated by the generated migration code with the key being the dialect.
var migrationStates = make(map[string]migrationState)
func getUpMigration(dialect string) map[int][]byte {
ms, ok := migrationStates[dialect]
if !ok {
return nil
}
return ms.upMigrations
}
func getDownMigration(dialect string) map[int][]byte {
ms, ok := migrationStates[dialect]
if !ok {
return nil
}
return ms.downMigrations
}
// DevMigration returns true iff the provided dialect has changes which are still in development.
func DevMigration(dialect string) bool {
ms, ok := migrationStates[dialect]
return ok && ms.devMigration
}
// BinarySchemaVersion provides the schema version that this binary supports for the provided dialect.
// If the binary doesn't support this dialect -1 is returned.
func BinarySchemaVersion(dialect string) int {
ms, ok := migrationStates[dialect]
if !ok {
return nilVersion
}
return ms.binarySchemaVersion
}

@ -0,0 +1,23 @@
package schema
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBinarySchemaVersion(t *testing.T) {
dialect := "test_binaryschemaversion"
migrationStates[dialect] = migrationState{binarySchemaVersion: 3}
assert.Equal(t, 3, BinarySchemaVersion(dialect))
assert.Equal(t, nilVersion, BinarySchemaVersion("unknown_dialect"))
}
func TestDevMigration(t *testing.T) {
dialect := "test_devmigrations"
migrationStates[dialect] = migrationState{devMigration: true}
assert.True(t, DevMigration(dialect))
migrationStates[dialect] = migrationState{devMigration: false}
assert.False(t, DevMigration(dialect))
assert.False(t, DevMigration("unknown_dialect"))
}

@ -0,0 +1,59 @@
package schema
import (
"fmt"
"sort"
"github.com/hashicorp/boundary/internal/errors"
)
// statementProvider provides the migration statements in order.
// Next should be called prior to calling Version() or ReadUp() or sentinel
// values (-1 and nil) will be returned.
type statementProvider struct {
pos int
versions []int
up, down map[int][]byte
}
func newStatementProvider(dialect string, curVer int) (*statementProvider, error) {
op := errors.Op("schema.newStatementProvider")
qp := statementProvider{pos: -1}
qp.up, qp.down = getUpMigration(dialect), getDownMigration(dialect)
if len(qp.up) != len(qp.down) {
return nil, errors.New(errors.MigrationIntegrity, op, fmt.Sprintf("Mismatch up/down size: up %d vs. down %d", len(qp.up), len(qp.down)))
}
for k := range qp.up {
if _, ok := qp.down[k]; !ok {
return nil, errors.New(errors.MigrationIntegrity, op, fmt.Sprintf("Up key %d doesn't exist in down %v", k, qp.down))
}
qp.versions = append(qp.versions, k)
}
sort.Ints(qp.versions)
for len(qp.versions) > 0 && qp.versions[0] <= curVer {
qp.versions = qp.versions[1:]
}
return &qp, nil
}
func (q *statementProvider) Next() bool {
q.pos++
return len(q.versions) > q.pos
}
func (q *statementProvider) Version() int {
if q.pos < 0 || q.pos >= len(q.versions) {
return -1
}
return q.versions[q.pos]
}
// ReadUp reads the current up migration
func (q *statementProvider) ReadUp() []byte {
if q.pos < 0 || q.pos >= len(q.versions) {
return nil
}
return q.up[q.versions[q.pos]]
}

@ -0,0 +1,87 @@
package schema
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestStatementProvider(t *testing.T) {
testDialect := "test"
migrationStates[testDialect] = migrationState{
binarySchemaVersion: 5,
upMigrations: map[int][]byte{
1: []byte("one"),
2: []byte("two"),
3: []byte("three"),
},
downMigrations: map[int][]byte{
1: []byte("down one"),
2: []byte("down two"),
3: []byte("down three"),
},
}
st, err := newStatementProvider(testDialect, 1)
assert.NoError(t, err)
assert.Equal(t, -1, st.Version())
assert.Equal(t, []byte(nil), st.ReadUp())
assert.True(t, st.Next())
assert.Equal(t, 2, st.Version())
assert.Equal(t, []byte("two"), st.ReadUp())
assert.True(t, st.Next())
assert.Equal(t, 3, st.Version())
assert.Equal(t, []byte("three"), st.ReadUp())
assert.False(t, st.Next())
assert.Equal(t, -1, st.Version())
assert.Equal(t, []byte(nil), st.ReadUp())
assert.False(t, st.Next())
assert.Equal(t, -1, st.Version())
assert.Equal(t, []byte(nil), st.ReadUp())
st, err = newStatementProvider("unknown_dialect", nilVersion)
assert.NoError(t, err)
assert.False(t, st.Next())
}
func TestStatementProvider_error(t *testing.T) {
cases := []struct {
name string
in migrationState
}{
{
name: "mismatchLength",
in: migrationState{
binarySchemaVersion: 5,
upMigrations: map[int][]byte{
1: []byte("one"),
},
downMigrations: map[int][]byte{},
},
},
{
name: "mismatchVersions",
in: migrationState{
binarySchemaVersion: 5,
upMigrations: map[int][]byte{
1: []byte("one"),
},
downMigrations: map[int][]byte{
2: []byte("two"),
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
migrationStates[tc.name] = tc.in
defer delete(migrationStates, tc.name)
_, err := newStatementProvider(tc.name, -1)
assert.Error(t, err)
})
}
}

@ -10,6 +10,7 @@ import (
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/oplog/store"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -23,6 +24,7 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin
var cleanup func() error
var url string
var err error
ctx := context.Background()
opts := getTestOpts(opt...)
@ -36,13 +38,14 @@ func TestSetup(t *testing.T, dialect string, opt ...TestOption) (*gorm.DB, strin
assert.NoError(t, cleanup(), "Got error cleaning up db in docker.")
})
default:
cleanup = func() error { return nil }
url = opts.withTestDatabaseUrl
}
_, err = InitStore(dialect, cleanup, url)
_, err = schema.InitStore(ctx, dialect, url)
if err != nil {
t.Fatal(err)
t.Fatalf("Couldn't init store on existing db: %v", err)
}
db, err := gorm.Open(dialect, url)
if err != nil {
t.Fatal(err)

@ -30,7 +30,7 @@ func startDbInDockerSupported(dialect string) (cleanup func() error, retURL, con
resource, err = pool.Run("postgres", "12", []string{"POSTGRES_PASSWORD=password", "POSTGRES_DB=boundary"})
url = "postgres://postgres:password@localhost:%s?sslmode=disable"
if err == nil {
url = fmt.Sprintf("postgres://postgres:password@%s?sslmode=disable", resource.GetHostPort("5432/tcp"))
url = fmt.Sprintf("postgres://postgres:password@%s/boundary?sslmode=disable", resource.GetHostPort("5432/tcp"))
}
default:
panic(fmt.Sprintf("unknown dialect %q", dialect))

@ -56,10 +56,14 @@ const (
NotNull Code = 1001 // NotNull represents a value must not be null error
NotUnique Code = 1002 // NotUnique represents a value must be unique error
NotSpecificIntegrity Code = 1003 // NotSpecificIntegrity represents an integrity error that has no specific domain error code
MissingTable Code = 1004 // Missing table represents an undefined table error
MissingTable Code = 1004 // MissingTable represents an undefined table error
RecordNotFound Code = 1100 // RecordNotFound represents that a record/row was not found matching the criteria
MultipleRecords Code = 1101 // MultipleRecords represents that multiple records/rows were found matching the criteria
ColumnNotFound Code = 1102 // ColumnNotFound represent that a column was not found in the underlying db
MaxRetries Code = 1103 // MaxRetries represent that a db Tx hit max retires allowed
Exception Code = 1104 // Exception represent that an underlying db exception was raised
// Migration setup errors are codes 2000-3000
MigrationIntegrity Code = 2000 // MigrationIntegrity represents an error with the generated migration related code
MigrationLock Code = 2001 // MigrationLock represents an error related to locking of the DB
)

@ -162,6 +162,16 @@ func TestCode_Both_String_Info(t *testing.T) {
c: MissingTable,
want: MissingTable,
},
{
name: "MigrationIntegrity",
c: MigrationIntegrity,
want: MigrationIntegrity,
},
{
name: "MigrationLock",
c: MigrationLock,
want: MigrationLock,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

@ -124,4 +124,12 @@ var errorCodeInfo = map[Code]Info{
Message: "too many retries",
Kind: Transaction,
},
MigrationIntegrity: {
Message: "migration integrity",
Kind: Integrity,
},
MigrationLock: {
Message: "bad db lock",
Kind: Integrity,
},
}

@ -1,11 +1,12 @@
package oplog
import (
"context"
"crypto/rand"
"database/sql"
"testing"
"github.com/golang-migrate/migrate/v4"
"github.com/hashicorp/boundary/internal/db/migrations"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/hashicorp/boundary/internal/oplog/oplog_test"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -79,16 +80,12 @@ func testWrapper(t *testing.T) wrapping.Wrapper {
// testInitStore will execute the migrations needed to initialize the store for tests
func testInitStore(t *testing.T, cleanup func() error, url string) {
t.Helper()
// run migrations
source, err := migrations.NewMigrationSource("postgres")
require.NoError(t, err, "Error creating migration source")
m, err := migrate.NewWithSourceInstance("postgres", source, url)
require.NoError(t, err, "Error creating migrations")
ctx := context.Background()
dialect := "postgres"
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
if err := cleanup(); err != nil {
t.Fatalf("error cleaning up after migration failure: %v", err)
}
require.NoError(t, err, "Error running migrations")
}
d, err := sql.Open(dialect, url)
require.NoError(t, err)
sm, err := schema.NewManager(ctx, dialect, d)
require.NoError(t, err)
require.NoError(t, sm.RollForward(ctx))
}

@ -14,7 +14,7 @@ import (
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/servers"
@ -324,10 +324,10 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
}
// Base server
tc.b = base.NewServer(nil)
tc.b.Command = &base.Command{
tc.b = base.NewServer(&base.Command{
Context: ctx,
ShutdownCh: make(chan struct{}),
}
})
// Get dev config, or use a provided one
var err error
@ -412,7 +412,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if opts.DatabaseUrl != "" {
tc.b.DatabaseUrl = opts.DatabaseUrl
if _, err := db.InitStore("postgres", nil, tc.b.DatabaseUrl); err != nil {
if _, err := schema.InitStore(ctx, "postgres", tc.b.DatabaseUrl); err != nil {
t.Fatal(err)
}
if err := tc.b.ConnectToDatabase("postgres"); err != nil {
@ -423,7 +423,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
t.Fatal(err)
}
if !opts.DisableInitialLoginRoleCreation {
if _, err := tc.b.CreateInitialLoginRole(context.Background()); err != nil {
if _, err := tc.b.CreateInitialLoginRole(ctx); err != nil {
t.Fatal(err)
}
if !opts.DisableAuthMethodCreation {
@ -453,7 +453,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if opts.DisableAuthMethodCreation {
createOpts = append(createOpts, base.WithSkipAuthMethodCreation())
}
if err := tc.b.CreateDevDatabase("postgres", createOpts...); err != nil {
if err := tc.b.CreateDevDatabase(ctx, "postgres", createOpts...); err != nil {
t.Fatal(err)
}
}

Loading…
Cancel
Save