mirror of https://github.com/hashicorp/boundary
Organize DB Schema Migration Code and DB Init Checks (#842)
Create Schema Manager and propagate contexts into the DB calls downstream. Add checks at controller startup and database init for correct schema version and schema migration dirty bit being set.pull/873/head
parent
c141a050e5
commit
ec6151d174
@ -1,128 +0,0 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/source"
|
||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||
)
|
||||
|
||||
// migrationDriver satisfies the remaining need of the Driver interface, since
|
||||
// the package uses PartialDriver under the hood
|
||||
type migrationDriver struct {
|
||||
dialect string
|
||||
}
|
||||
|
||||
// Open returns the given "file"
|
||||
func (m *migrationDriver) Open(name string) (http.File, error) {
|
||||
return newFakeFile(m.dialect, name)
|
||||
}
|
||||
|
||||
// NewMigrationSource creates a source.Driver using httpfs with the given dialect
|
||||
func NewMigrationSource(dialect string) (source.Driver, error) {
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown migrations dialect %s", dialect)
|
||||
}
|
||||
return httpfs.New(&migrationDriver{dialect}, "migrations")
|
||||
}
|
||||
|
||||
// fakeFile is used to satisfy the http.File interface
|
||||
type fakeFile struct {
|
||||
name string
|
||||
bytes []byte
|
||||
reader *bytes.Reader
|
||||
dialect string
|
||||
}
|
||||
|
||||
func newFakeFile(dialect string, name string) (*fakeFile, error) {
|
||||
var ff *fakeFile
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
ff = postgresMigrations[name]
|
||||
}
|
||||
if ff == nil {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
ff.name = strings.TrimPrefix(name, "migrations/")
|
||||
ff.reader = bytes.NewReader(ff.bytes)
|
||||
ff.dialect = dialect
|
||||
return ff, nil
|
||||
}
|
||||
|
||||
func (f *fakeFile) Read(p []byte) (n int, err error) {
|
||||
return f.reader.Read(p)
|
||||
}
|
||||
|
||||
func (f *fakeFile) Seek(offset int64, whence int) (int64, error) {
|
||||
return f.reader.Seek(offset, whence)
|
||||
}
|
||||
|
||||
func (f *fakeFile) Close() error { return nil }
|
||||
|
||||
// Readdir returns os.FileInfo values, in sorted order, and eliding the
|
||||
// migrations "dir"
|
||||
func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||
// Get the right map
|
||||
var migrationsMap map[string]*fakeFile
|
||||
switch f.dialect {
|
||||
case "postgres":
|
||||
migrationsMap = postgresMigrations
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database dialect %s", f.dialect)
|
||||
}
|
||||
|
||||
// Sort the keys. May not be necessary but feels nice.
|
||||
keys := make([]string, 0, len(migrationsMap))
|
||||
for k := range migrationsMap {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// Create the slice of fileinfo objects to return
|
||||
ret := make([]os.FileInfo, 0, len(migrationsMap))
|
||||
for i, v := range keys {
|
||||
// We need "migrations" in the map for the initial Open call but we
|
||||
// should not return it as part of the "directory"'s "files".
|
||||
if v == "migrations" {
|
||||
continue
|
||||
}
|
||||
stat, err := migrationsMap[v].Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = append(ret, stat)
|
||||
if count > 0 && count == i {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Stat returns a new fakeFileInfo object with the necessary bits
|
||||
func (f *fakeFile) Stat() (os.FileInfo, error) {
|
||||
return &fakeFileInfo{
|
||||
name: f.name,
|
||||
size: int64(len(f.bytes)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fakeFileInfo satisfies os.FileInfo but represents our fake "files"
|
||||
type fakeFileInfo struct {
|
||||
name string
|
||||
size int64
|
||||
}
|
||||
|
||||
func (f *fakeFileInfo) Name() string { return f.name }
|
||||
func (f *fakeFileInfo) Size() int64 { return f.size }
|
||||
func (f *fakeFileInfo) Mode() os.FileMode { return os.ModePerm }
|
||||
func (f *fakeFileInfo) ModTime() time.Time { return time.Now() }
|
||||
func (f *fakeFileInfo) IsDir() bool { return false }
|
||||
func (f *fakeFileInfo) Sys() interface{} { return nil }
|
||||
@ -1,177 +0,0 @@
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/source"
|
||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewMigrationSource(t *testing.T) {
|
||||
type args struct {
|
||||
dialect string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want source.Driver
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "postgres",
|
||||
args: args{dialect: "postgres"},
|
||||
want: func() source.Driver {
|
||||
d, err := httpfs.New(&migrationDriver{"postgres"}, "migrations")
|
||||
if err != nil {
|
||||
t.Errorf("NewMigrationSource() error creating httpfs = %w", err)
|
||||
}
|
||||
return d
|
||||
}(),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no-dialect",
|
||||
args: args{dialect: ""},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "bad-dialect",
|
||||
args: args{dialect: "rainbows-and-unicorns-db"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewMigrationSource(tt.args.dialect)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewMigrationSource() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewMigrationSource() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_migrationDriver_Open(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
dialect string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid-file",
|
||||
dialect: "postgres",
|
||||
args: args{name: "migrations/01_domain_types.up.sql"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-file",
|
||||
dialect: "postgres",
|
||||
args: args{name: "migrations/unicorns-and-rainbows.up.sql"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &migrationDriver{
|
||||
dialect: tt.dialect,
|
||||
}
|
||||
_, err := m.Open(tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("migrationDriver.Open() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_fakeFile_Read(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
|
||||
assert.NoError(err)
|
||||
buf := make([]byte, len(ff.bytes))
|
||||
n, err := ff.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(len(buf), n)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_fakeFile_Seek(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
ff, err := newFakeFile("postgres", "migrations/01_domain_types.up.sql")
|
||||
assert.NoError(err)
|
||||
buf := make([]byte, len(ff.bytes))
|
||||
n, err := ff.Seek(10, 0)
|
||||
assert.NoError(err)
|
||||
assert.Equal(int64(10), n)
|
||||
|
||||
n2, err := ff.Read(buf)
|
||||
assert.NoError(err)
|
||||
assert.Equal(len(ff.bytes)-10, n2)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_fakeFile_Close(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
m := &migrationDriver{
|
||||
dialect: "postgres",
|
||||
}
|
||||
f, err := m.Open("migrations/01_domain_types.up.sql")
|
||||
assert.NoError(err)
|
||||
err = f.Close()
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_fakeFile_Stat(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
name := "migrations/01_domain_types.up.sql"
|
||||
ff, err := newFakeFile("postgres", name)
|
||||
assert.NoError(err)
|
||||
info, err := ff.Stat()
|
||||
assert.NoError(err)
|
||||
assert.Equal(ff.name, info.Name())
|
||||
assert.Equal(int64(len(ff.bytes)), info.Size())
|
||||
assert.Equal(os.ModePerm, info.Mode())
|
||||
assert.Equal(false, info.IsDir())
|
||||
assert.Equal(nil, info.Sys())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_fakeFile_Readdir(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
name := "migrations/01_domain_types.up.sql"
|
||||
ff, err := newFakeFile("postgres", name)
|
||||
assert.NoError(err)
|
||||
info, err := ff.Readdir(0)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(info)
|
||||
|
||||
info, err = ff.Readdir(1)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(info)
|
||||
assert.Equal(1, len(info))
|
||||
|
||||
info, err = ff.Readdir(0)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(info)
|
||||
// we don't want to count "migrations", so we're len - 1
|
||||
assert.Equal(len(postgresMigrations)-1, len(info))
|
||||
})
|
||||
}
|
||||
@ -1,149 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// generate looks for migration sql in a directory for the given dialect and
|
||||
// applies the templates below to the contents of the files, building up a
|
||||
// migrations map for the dialect
|
||||
func generate(dialect string) {
|
||||
baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/migrations"
|
||||
dir, err := os.Open(fmt.Sprintf("%s/%s", baseDir, dialect))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
versions, err := dir.Readdirnames(0)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
outBuf := bytes.NewBuffer(nil)
|
||||
valuesBuf := bytes.NewBuffer(nil)
|
||||
|
||||
sort.Strings(versions)
|
||||
|
||||
isDev := false
|
||||
largestVer := 0
|
||||
for _, ver := range versions {
|
||||
var verVal int
|
||||
switch ver {
|
||||
case "dev":
|
||||
verVal = largestVer + 1
|
||||
default:
|
||||
if verVal, err = strconv.Atoi(ver); err != nil {
|
||||
fmt.Printf("error reading major schema version directory %q. Must be a number or 'dev'\n", ver)
|
||||
os.Exit(1)
|
||||
}
|
||||
if verVal > largestVer {
|
||||
largestVer = verVal
|
||||
}
|
||||
}
|
||||
|
||||
dir, err := os.Open(fmt.Sprintf("%s/%s/%s", baseDir, dialect, ver))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
names, err := dir.Readdirnames(0)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if ver == "dev" && len(names) > 0 {
|
||||
isDev = true
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
if !strings.HasSuffix(name, ".sql") {
|
||||
continue
|
||||
}
|
||||
|
||||
contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", baseDir, dialect, ver, name))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
vName := name
|
||||
nameParts := strings.SplitN(name, "_", 2)
|
||||
if len(nameParts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
nameVer, err := strconv.Atoi(nameParts[0])
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to get file version from %q\n", name)
|
||||
continue
|
||||
}
|
||||
vName = fmt.Sprintf("%02d_%s", (verVal*1000)+nameVer, nameParts[1])
|
||||
|
||||
if err := migrationsValueTemplate.Execute(valuesBuf, struct {
|
||||
Name string
|
||||
Contents string
|
||||
}{
|
||||
Name: vName,
|
||||
Contents: string(contents),
|
||||
}); err != nil {
|
||||
fmt.Printf("error executing migrations value template for file %s/%s: %s", ver, name, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := migrationsTemplate.Execute(outBuf, struct {
|
||||
Type string
|
||||
Values string
|
||||
DevMigration bool
|
||||
}{
|
||||
Type: dialect,
|
||||
Values: valuesBuf.String(),
|
||||
DevMigration: isDev,
|
||||
}); err != nil {
|
||||
fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
outFile := fmt.Sprintf("%s/%s.gen.go", baseDir, dialect)
|
||||
if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil {
|
||||
fmt.Printf("error writing file %q: %v\n", outFile, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var migrationsTemplate = template.Must(template.New("").Parse(
|
||||
`// Code generated by "make migrations"; DO NOT EDIT.
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// DevMigration is true if the database schema that would be applied by
|
||||
// InitStore would be from files in the /dev directory which indicates it would
|
||||
// not be safe to run in a non dev environment.
|
||||
var DevMigration = {{ .DevMigration }}
|
||||
|
||||
var {{ .Type }}Migrations = map[string]*fakeFile{
|
||||
"migrations": {
|
||||
name: "migrations",
|
||||
},
|
||||
{{ .Values }}
|
||||
}
|
||||
`))
|
||||
|
||||
var migrationsValueTemplate = template.Must(template.New("").Parse(
|
||||
`"migrations/{{ .Name }}": {
|
||||
name: "{{ .Name }}",
|
||||
bytes: []byte(` + "`\n{{ .Contents }}\n`" + `),
|
||||
},
|
||||
`))
|
||||
@ -0,0 +1,187 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db/schema/postgres"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
)
|
||||
|
||||
// driver provides functionality to a database.
|
||||
type driver interface {
|
||||
TrySharedLock(context.Context) error
|
||||
TryLock(context.Context) error
|
||||
Lock(context.Context) error
|
||||
Unlock(context.Context) error
|
||||
UnlockShared(context.Context) error
|
||||
Run(context.Context, io.Reader) error
|
||||
// A value of -1 indicates no version is set.
|
||||
SetVersion(context.Context, int, bool) error
|
||||
// A value of -1 indicates no version is set.
|
||||
Version(context.Context) (int, bool, error)
|
||||
}
|
||||
|
||||
// Manager provides a way to run operations and retrieve information regarding
|
||||
// the underlying boundary database schema.
|
||||
// Manager is not thread safe.
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
driver driver
|
||||
dialect string
|
||||
}
|
||||
|
||||
// NewManager creates a new schema manager. An error is returned
|
||||
// if the provided dialect is unrecognized or if the passed in db is unreachable.
|
||||
func NewManager(ctx context.Context, dialect string, db *sql.DB) (*Manager, error) {
|
||||
const op = "schema.NewManager"
|
||||
dbM := Manager{db: db, dialect: dialect}
|
||||
switch dialect {
|
||||
case "postgres":
|
||||
var err error
|
||||
dbM.driver, err = postgres.New(ctx, db)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, op)
|
||||
}
|
||||
default:
|
||||
return nil, errors.New(errors.InvalidParameter, op, fmt.Sprintf("unknown dialect %q", dialect))
|
||||
}
|
||||
return &dbM, nil
|
||||
}
|
||||
|
||||
// State contains information regarding the current state of a boundary database's schema.
|
||||
type State struct {
|
||||
// InitializationStarted indicates if the current database has already been initialized
|
||||
// (successfully or not) at least once.
|
||||
InitializationStarted bool
|
||||
// Dirty is set to true if the database failed in a previous migration/initialization.
|
||||
Dirty bool
|
||||
// DatabaseSchemaVersion is the schema version that is currently running in the database.
|
||||
DatabaseSchemaVersion int
|
||||
// BinarySchemaVersion is the schema version which this boundary binary supports.
|
||||
BinarySchemaVersion int
|
||||
}
|
||||
|
||||
// CurrentState provides the state of the boundary schema contained in the backing database.
|
||||
func (b *Manager) CurrentState(ctx context.Context) (*State, error) {
|
||||
dbS := State{
|
||||
BinarySchemaVersion: BinarySchemaVersion(b.dialect),
|
||||
}
|
||||
v, dirty, err := b.driver.Version(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v == nilVersion {
|
||||
return &dbS, nil
|
||||
}
|
||||
dbS.InitializationStarted = true
|
||||
dbS.DatabaseSchemaVersion = v
|
||||
dbS.Dirty = dirty
|
||||
return &dbS, nil
|
||||
}
|
||||
|
||||
// SharedLock attempts to obtain a shared lock on the database. This can fail if
|
||||
// an exclusive lock is already held with the same key. An error is returned if
|
||||
// a lock was unable to be obtained.
|
||||
func (b *Manager) SharedLock(ctx context.Context) error {
|
||||
const op = "schema.(Manager).SharedLock"
|
||||
if err := b.driver.TrySharedLock(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SharedUnlock releases a shared lock on the database. If this
|
||||
// fails for whatever reason an error is returned. Unlocking a lock
|
||||
// that is not held is not an error.
|
||||
func (b *Manager) SharedUnlock(ctx context.Context) error {
|
||||
const op = "schema.(Manager).SharedUnlock"
|
||||
if err := b.driver.UnlockShared(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExclusiveLock attempts to obtain an exclusive lock on the database.
|
||||
// An error is returned if a lock was unable to be obtained.
|
||||
func (b *Manager) ExclusiveLock(ctx context.Context) error {
|
||||
const op = "schema.(Manager).ExclusiveLock"
|
||||
if err := b.driver.TryLock(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExclusiveUnlock releases a shared lock on the database. If this
|
||||
// fails for whatever reason an error is returned. Unlocking a lock
|
||||
// that is not held is not an error.
|
||||
func (b *Manager) ExclusiveUnlock(ctx context.Context) error {
|
||||
const op = "schema.(Manager).ExclusiveUnlock"
|
||||
if err := b.driver.Unlock(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollForward updates the database schema to match the latest version known by
|
||||
// the boundary binary. An error is not returned if the database is already at
|
||||
// the most recent version.
|
||||
func (b *Manager) RollForward(ctx context.Context) error {
|
||||
const op = "schema.(Manager).RollForward"
|
||||
|
||||
// Capturing a lock that this session to the db already possesses is okay.
|
||||
if err := b.driver.Lock(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
defer func() {
|
||||
b.driver.Unlock(ctx)
|
||||
}()
|
||||
|
||||
curVersion, dirty, err := b.driver.Version(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
if dirty {
|
||||
return errors.New(errors.NotSpecificIntegrity, op, fmt.Sprintf("schema is dirty with version %d", curVersion))
|
||||
}
|
||||
|
||||
sp, err := newStatementProvider(b.dialect, curVersion)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return b.runMigrations(ctx, sp)
|
||||
}
|
||||
|
||||
// runMigrations passes migration queries to a database driver and manages
|
||||
// the version and dirty bit. Cancelation or deadline/timeout is managed
|
||||
// through the passed in context.
|
||||
func (b *Manager) runMigrations(ctx context.Context, qp *statementProvider) error {
|
||||
const op = "schema.(Manager).runMigrations"
|
||||
for qp.Next() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), op)
|
||||
default:
|
||||
// context is not done yet. Continue on to the next query to execute.
|
||||
}
|
||||
|
||||
// set version with dirty state
|
||||
if err := b.driver.SetVersion(ctx, qp.Version(), true); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
if err := b.driver.Run(ctx, bytes.NewReader(qp.ReadUp())); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
// set clean state
|
||||
if err := b.driver.SetVersion(ctx, qp.Version(), false); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,201 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db/schema/postgres"
|
||||
"github.com/hashicorp/boundary/internal/docker"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
d, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = NewManager(ctx, "postgres", d)
|
||||
require.NoError(t, err)
|
||||
_, err = NewManager(ctx, "unknown", d)
|
||||
assert.True(t, errors.Match(errors.T(errors.InvalidParameter), err))
|
||||
|
||||
d.Close()
|
||||
_, err = NewManager(ctx, "postgres", d)
|
||||
assert.True(t, errors.Match(errors.T(errors.Op("schema.NewManager")), err))
|
||||
}
|
||||
|
||||
func TestCurrentState(t *testing.T) {
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
t.Cleanup(func() {
|
||||
if err := c(); err != nil {
|
||||
t.Fatalf("Got error at cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
ctx := context.Background()
|
||||
d, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
|
||||
m, err := NewManager(ctx, "postgres", d)
|
||||
require.NoError(t, err)
|
||||
want := &State{
|
||||
BinarySchemaVersion: BinarySchemaVersion("postgres"),
|
||||
}
|
||||
s, err := m.CurrentState(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, s)
|
||||
|
||||
testDriver, err := postgres.New(ctx, d)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testDriver.SetVersion(ctx, 2, true))
|
||||
|
||||
want = &State{
|
||||
InitializationStarted: true,
|
||||
BinarySchemaVersion: BinarySchemaVersion("postgres"),
|
||||
Dirty: true,
|
||||
DatabaseSchemaVersion: 2,
|
||||
}
|
||||
s, err = m.CurrentState(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, s)
|
||||
}
|
||||
|
||||
func TestRollForward(t *testing.T) {
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
d, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
m, err := NewManager(ctx, "postgres", d)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, m.RollForward(ctx))
|
||||
|
||||
// Now set to dirty at an early version
|
||||
testDriver, err := postgres.New(ctx, d)
|
||||
require.NoError(t, err)
|
||||
testDriver.SetVersion(ctx, 0, true)
|
||||
assert.Error(t, m.RollForward(ctx))
|
||||
}
|
||||
|
||||
func TestRollForward_NotFromFresh(t *testing.T) {
|
||||
dialect := "postgres"
|
||||
oState := migrationStates[dialect]
|
||||
|
||||
nState := createPartialMigrationState(oState, 8)
|
||||
migrationStates[dialect] = nState
|
||||
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
d, err := sql.Open(dialect, u)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Initialize the DB with only a portion of the current sql scripts.
|
||||
ctx := context.Background()
|
||||
m, err := NewManager(ctx, dialect, d)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, m.RollForward(ctx))
|
||||
|
||||
ver, dirty, err := m.driver.Version(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, nState.binarySchemaVersion, ver)
|
||||
assert.False(t, dirty)
|
||||
|
||||
// Restore the full set of sql scripts and roll the rest of the way forward.
|
||||
migrationStates[dialect] = oState
|
||||
|
||||
newM, err := NewManager(ctx, dialect, d)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, newM.RollForward(ctx))
|
||||
ver, dirty, err = newM.driver.Version(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, oState.binarySchemaVersion, ver)
|
||||
assert.False(t, dirty)
|
||||
}
|
||||
|
||||
func TestManager_ExclusiveLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
d1, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
m1, err := NewManager(ctx, "postgres", d1)
|
||||
require.NoError(t, err)
|
||||
|
||||
d2, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
m2, err := NewManager(ctx, "postgres", d2)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NoError(t, m1.ExclusiveLock(ctx))
|
||||
assert.NoError(t, m1.ExclusiveLock(ctx))
|
||||
assert.Error(t, m2.ExclusiveLock(ctx))
|
||||
assert.Error(t, m2.SharedLock(ctx))
|
||||
}
|
||||
|
||||
func TestManager_SharedLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c, u, _, err := docker.StartDbInDocker("postgres")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
d1, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
m1, err := NewManager(ctx, "postgres", d1)
|
||||
require.NoError(t, err)
|
||||
|
||||
d2, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
m2, err := NewManager(ctx, "postgres", d2)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NoError(t, m1.SharedLock(ctx))
|
||||
assert.NoError(t, m2.SharedLock(ctx))
|
||||
assert.NoError(t, m1.SharedLock(ctx))
|
||||
assert.NoError(t, m2.SharedLock(ctx))
|
||||
|
||||
assert.Error(t, m1.ExclusiveLock(ctx))
|
||||
assert.Error(t, m2.ExclusiveLock(ctx))
|
||||
}
|
||||
|
||||
// Creates a new migrationState only with the versions <= the provided maxVer
|
||||
func createPartialMigrationState(om migrationState, maxVer int) migrationState {
|
||||
nState := migrationState{
|
||||
devMigration: om.devMigration,
|
||||
upMigrations: make(map[int][]byte),
|
||||
downMigrations: make(map[int][]byte),
|
||||
}
|
||||
for k := range om.upMigrations {
|
||||
if k > maxVer {
|
||||
// Don't store any versions past our test version.
|
||||
continue
|
||||
}
|
||||
nState.upMigrations[k] = om.upMigrations[k]
|
||||
nState.downMigrations[k] = om.downMigrations[k]
|
||||
if nState.binarySchemaVersion < k {
|
||||
nState.binarySchemaVersion = k
|
||||
}
|
||||
}
|
||||
return nState
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
# migrations package
|
||||
This package handles the generation of the database schema in a format that can
|
||||
be compiled into the boundary binary.
|
||||
|
||||
## Organization
|
||||
|
||||
* `./generate`: contains the makefile, code, and templates needed to generate the schema info.
|
||||
* `./postgres`: contains the versioned schema folders. The contents of these folders, except
|
||||
for `dev` should not be modified.
|
||||
|
||||
## Usage
|
||||
To regenerate the schema information into the format the boundary binary uses run
|
||||
`make migrations` or `make gen` to recreate all generated code.
|
||||
|
||||
The content of the folders under `./postgres` are compiled into the
|
||||
boundary binary and when the `boundary database init` or `boundary database migrate`
|
||||
commands are executed they are applied in order of their version.
|
||||
|
||||
The `./postgres/dev` directory contains schema files that are under development and
|
||||
are not included in a release yet and so it is the only directory where additions and
|
||||
modifications are allowed. When a boundary binary is built when this directory is not
|
||||
empty a special flag is required to run the `boundary database init` command to indicate
|
||||
the user is aware that this is a development release and running this command can
|
||||
result in a completely broken schema and dataloss.
|
||||
|
||||
When a new release is made the contents of the `dev` directory are moved into a new
|
||||
versioned directory.
|
||||
@ -0,0 +1,158 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// generate looks for migration sql in a directory for the given dialect and
|
||||
// applies the templates below to the contents of the files, building up a
|
||||
// migrations map for the dialect
|
||||
func generate(dialect string) {
|
||||
baseDir := os.Getenv("GEN_BASEPATH") + "/internal/db/schema"
|
||||
srcDir := baseDir + "/migrations"
|
||||
dir, err := os.Open(fmt.Sprintf("%s/%s", srcDir, dialect))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
versions, err := dir.Readdirnames(0)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
sort.Strings(versions)
|
||||
|
||||
type ContentValues struct {
|
||||
Name string
|
||||
Content string
|
||||
}
|
||||
var upContents []ContentValues
|
||||
var downContents []ContentValues
|
||||
|
||||
isDev := false
|
||||
var lRelVer, largestSchemaVersion int
|
||||
for _, ver := range versions {
|
||||
var verVal int
|
||||
switch ver {
|
||||
case "dev":
|
||||
verVal = lRelVer + 1
|
||||
default:
|
||||
v, err := strconv.Atoi(ver)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading major schema version directory %q. Must be a number or 'dev'\n", ver)
|
||||
os.Exit(1)
|
||||
}
|
||||
verVal = v
|
||||
if verVal > lRelVer {
|
||||
lRelVer = verVal
|
||||
}
|
||||
}
|
||||
|
||||
dir, err := os.Open(fmt.Sprintf("%s/%s/%s", srcDir, dialect, ver))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
names, err := dir.Readdirnames(0)
|
||||
if err != nil {
|
||||
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if ver == "dev" && len(names) > 0 {
|
||||
isDev = true
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
for _, name := range names {
|
||||
if !strings.HasSuffix(name, ".sql") {
|
||||
continue
|
||||
}
|
||||
|
||||
contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s/%s", srcDir, dialect, ver, name))
|
||||
if err != nil {
|
||||
fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
nameParts := strings.SplitN(name, "_", 2)
|
||||
if len(nameParts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
v, err := strconv.Atoi(nameParts[0])
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to get file version from %q\n", name)
|
||||
continue
|
||||
}
|
||||
|
||||
fullV := (verVal * 1000) + v
|
||||
if fullV > largestSchemaVersion {
|
||||
largestSchemaVersion = fullV
|
||||
}
|
||||
cv := ContentValues{
|
||||
Name: fmt.Sprint(fullV),
|
||||
Content: string(contents),
|
||||
}
|
||||
switch {
|
||||
case strings.Contains(nameParts[1], ".down."):
|
||||
downContents = append(downContents, cv)
|
||||
case strings.Contains(nameParts[1], ".up."):
|
||||
upContents = append(upContents, cv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fmt.Printf("Got upcontent: %#v\n\n downcontents: %#v", upContents[:1], downContents[:1])
|
||||
|
||||
outBuf := bytes.NewBuffer(nil)
|
||||
if err := migrationsTemplate.Execute(outBuf, struct {
|
||||
Type string
|
||||
UpValues []ContentValues
|
||||
DownValues []ContentValues
|
||||
DevMigration bool
|
||||
BinarySchemaVersion int
|
||||
}{
|
||||
Type: dialect,
|
||||
UpValues: upContents,
|
||||
DownValues: downContents,
|
||||
DevMigration: isDev,
|
||||
BinarySchemaVersion: largestSchemaVersion,
|
||||
}); err != nil {
|
||||
fmt.Printf("error executing migrations value template for dialect %s: %s", dialect, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
outFile := fmt.Sprintf("%s/%s_migration.gen.go", baseDir, dialect)
|
||||
if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0o644); err != nil {
|
||||
fmt.Printf("error writing file %q: %v\n", outFile, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var migrationsTemplate = template.Must(template.Must(template.New("Content").Parse(
|
||||
`{{ .Name }}: []byte(` + "`\n{{ .Content }}\n`" + `),
|
||||
`)).New("MainPage").Parse(`package schema
|
||||
|
||||
// Code generated by "make migrations"; DO NOT EDIT.
|
||||
|
||||
func init() {
|
||||
migrationStates["{{ .Type }}"] = migrationState{
|
||||
devMigration: {{ .DevMigration }},
|
||||
binarySchemaVersion: {{ .BinarySchemaVersion }},
|
||||
upMigrations: map[int][]byte{
|
||||
{{range .UpValues }}{{ template "Content" . }}{{end}}
|
||||
},
|
||||
downMigrations: map[int][]byte{
|
||||
{{range .DownValues }}{{ template "Content" . }}{{end}}
|
||||
},
|
||||
}
|
||||
}
|
||||
`))
|
||||
@ -0,0 +1,356 @@
|
||||
// The MIT License (MIT)
|
||||
//
|
||||
// Original Work
|
||||
// Copyright (c) 2016 Matthias Kadenbach
|
||||
// https://github.com/mattes/migrate
|
||||
//
|
||||
// Modified Work
|
||||
// Copyright (c) 2018 Dale Hui
|
||||
// https://github.com/golang-migrate/migrate
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
// THE SOFTWARE.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// schemaAccessLockId is a Lock key used to ensure a single boundary binary is operating
|
||||
// on a postgres server at a time. The value has no meaning and was picked randomly.
|
||||
const (
|
||||
schemaAccessLockId int64 = 3865661975
|
||||
nilVersion = -1
|
||||
)
|
||||
|
||||
var defaultMigrationsTable = "boundary_schema_version"
|
||||
|
||||
// Postgres is a driver usable by a boundary schema manager.
|
||||
type Postgres struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// New returns a postgres pointer with the provided db verified as
|
||||
// connectable and a version table being initialized.
|
||||
func New(ctx context.Context, instance *sql.DB) (*Postgres, error) {
|
||||
const op = "postgres.New"
|
||||
if err := instance.PingContext(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, op)
|
||||
}
|
||||
conn, err := instance.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
px := &Postgres{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
// TrySharedLock attempts to capture a shared lock. If it is not successful it returns an error.
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) TrySharedLock(ctx context.Context) error {
|
||||
const op = "postgres.(Postgres).TrySharedLock"
|
||||
r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock_shared($1)", schemaAccessLockId)
|
||||
if r.Err() != nil {
|
||||
return errors.Wrap(r.Err(), op)
|
||||
}
|
||||
var gotLock bool
|
||||
if err := r.Scan(&gotLock); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
if !gotLock {
|
||||
return errors.New(errors.MigrationLock, op, "Lock failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TryLock attempts to capture an exclusive lock. If it is not successful it returns an error.
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) TryLock(ctx context.Context) error {
|
||||
const op = "postgres.(Postgres).TryLock"
|
||||
|
||||
r := p.conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", schemaAccessLockId)
|
||||
if r.Err() != nil {
|
||||
return errors.Wrap(r.Err(), op)
|
||||
}
|
||||
var gotLock bool
|
||||
if err := r.Scan(&gotLock); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
if !gotLock {
|
||||
return errors.New(errors.MigrationLock, op, "Lock failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lock calls pg_advisory_lock with the provided context and returns an error
|
||||
// if we were unable to get the lock before the context cancels.
|
||||
func (p *Postgres) Lock(ctx context.Context) error {
|
||||
const op = "postgres.(Postgres).Lock"
|
||||
|
||||
// This will wait indefinitely until the Lock can be acquired.
|
||||
query := `SELECT pg_advisory_lock($1)`
|
||||
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock calls pg_advisory_unlock and returns an error if we were unable to
|
||||
// release the lock before the context cancels.
|
||||
func (p *Postgres) Unlock(ctx context.Context) error {
|
||||
const op = "postgres.(Postgres).Unlock"
|
||||
|
||||
query := `SELECT pg_advisory_unlock($1)`
|
||||
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlockShared calls pg_advisory_unlock_shared and returns an error if we were unable to
|
||||
// release the lock before the context cancels.
|
||||
func (p *Postgres) UnlockShared(ctx context.Context) error {
|
||||
const op = "postgres.(Postgres).UnlockShared"
|
||||
query := `SELECT pg_advisory_unlock_shared($1)`
|
||||
if _, err := p.conn.ExecContext(ctx, query, schemaAccessLockId); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Executes the sql provided in the passed in io.Reader. The contents of the reader must
|
||||
// fit in memory as the full content is read into a string before being passed to the
|
||||
// backing database.
|
||||
func (p *Postgres) Run(ctx context.Context, migration io.Reader) error {
|
||||
const op = "postgres.(Postgres).Run"
|
||||
migr, err := ioutil.ReadAll(migration)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
// Run migration
|
||||
query := string(migr)
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
if pgErr, ok := err.(*pq.Error); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
if pgErr.Position != "" {
|
||||
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
|
||||
line, col, lineColOK = computeLineFromPos(query, int(pos))
|
||||
}
|
||||
}
|
||||
message := fmt.Sprintf("migration failed")
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
message = fmt.Sprintf("%s, on line %v: %s", message, line, migr)
|
||||
return errors.Wrap(err, op, errors.WithMsg(message))
|
||||
}
|
||||
return errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("migration failed: %s", migr)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
||||
// replace crlf with lf
|
||||
s = strings.Replace(s, "\r\n", "\n", -1)
|
||||
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
||||
runes := []rune(s)
|
||||
if pos > len(runes) {
|
||||
return 0, 0, false
|
||||
}
|
||||
sel := runes[:pos]
|
||||
line = uint(runesCount(sel, newLine) + 1)
|
||||
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
||||
return line, col, true
|
||||
}
|
||||
|
||||
const newLine = '\n'
|
||||
|
||||
func runesCount(input []rune, target rune) int {
|
||||
var count int
|
||||
for _, r := range input {
|
||||
if r == target {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runesLastIndex(input []rune, target rune) int {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
if input[i] == target {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// SetVersion sets the version number, and whether the database is in a dirty state.
|
||||
// A version value of -1 indicates no version is set.
|
||||
func (p *Postgres) SetVersion(ctx context.Context, version int, dirty bool) error {
|
||||
const op = "postgres.(Postgres).SetVersion"
|
||||
tx, err := p.conn.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `TRUNCATE ` + pq.QuoteIdentifier(defaultMigrationsTable)
|
||||
if _, err := tx.ExecContext(ctx, query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
// Also re-write the schema Version for nil dirty versions to prevent
|
||||
// empty schema Version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == nilVersion && dirty) {
|
||||
query = `INSERT INTO ` + pq.QuoteIdentifier(defaultMigrationsTable) +
|
||||
` (Version, dirty) VALUES ($1, $2)`
|
||||
if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Version returns the version, if the database is currently in a dirty state, and any error.
|
||||
// A version value of -1 indicates no version is set.
|
||||
func (p *Postgres) Version(ctx context.Context) (version int, dirty bool, err error) {
|
||||
const op = "postgres.(Postgres).Version"
|
||||
query := `SELECT Version, dirty FROM ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` LIMIT 1`
|
||||
err = p.conn.QueryRowContext(ctx, query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return nilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pq.Error); ok {
|
||||
if e.Code.Name() == "undefined_table" {
|
||||
return nilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, errors.Wrap(err, op)
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Postgres) drop(ctx context.Context) (err error) {
|
||||
const op = "postgres.(Postgres).drop"
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
||||
tables, err := p.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
err = errors.Wrap(err, op)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the postgres type.
|
||||
func (p *Postgres) ensureVersionTable(ctx context.Context) (err error) {
|
||||
const op = "postgres.(Postgres).ensureVersionTable"
|
||||
if err = p.Lock(ctx); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := p.Unlock(ctx); e != nil {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}()
|
||||
|
||||
query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(defaultMigrationsTable) + ` (Version bigint primary key, dirty boolean not null)`
|
||||
if _, err = p.conn.ExecContext(ctx, query); err != nil {
|
||||
return errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,433 @@
|
||||
// The MIT License (MIT)
|
||||
//
|
||||
// Original Work
|
||||
// Copyright (c) 2016 Matthias Kadenbach
|
||||
// https://github.com/mattes/migrate
|
||||
//
|
||||
// Modified Work
|
||||
// Copyright (c) 2018 Dale Hui
|
||||
// https://github.com/golang-migrate/migrate
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
// THE SOFTWARE.
|
||||
|
||||
package postgres
|
||||
|
||||
// error codes https://github.com/lib/pq/blob/master/error.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
pgPassword = "postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
}
|
||||
// Supported versions: https://www.postgresql.org/support/versioning/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "postgres:12", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func pgConnectionString(host, port string) string {
|
||||
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?sslmode=disable", pgPassword, host, port)
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestDbStuff(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
d, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.close(t); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestVersion_NoVersionTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
d, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.close(t); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
// Drop the version table so calls to Version don't rely on that
|
||||
d.drop(ctx)
|
||||
|
||||
v, dirt, err := d.Version(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, v, nilVersion)
|
||||
assert.False(t, dirt)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiStatement(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
d, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.close(t); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(ctx, strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
d, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.close(t); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foobar schema
|
||||
if err := d.Run(ctx, strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.SetVersion(ctx, 1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := open(t, ctx, fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&search_path=foobar",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d2.close(t); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
version, _, err := d2.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != nilVersion {
|
||||
t.Fatal("expected NilVersion")
|
||||
}
|
||||
|
||||
// now update Version and compare
|
||||
if err := d2.SetVersion(ctx, 2, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _, err = d2.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Fatal("expected Version 2")
|
||||
}
|
||||
|
||||
// meanwhile, the public schema still has the other Version
|
||||
version, _, err = d.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Fatal("expected Version 2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgres_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
ps, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
test(t, ps, []byte("SELECT 1"))
|
||||
|
||||
err = ps.Lock(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance_Concurrent(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The number of concurrent processes running New
|
||||
const concurrency = 30
|
||||
|
||||
// We can instantiate a single database handle because it is
|
||||
// actually a connection pool, and so, each of the below go
|
||||
// routines will have a high probability of using a separate
|
||||
// connection, which is something we want to exercise.
|
||||
db, err := sql.Open("postgres", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetMaxIdleConns(concurrency)
|
||||
db.SetMaxOpenConns(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_, err := New(context.Background(), db)
|
||||
if err != nil {
|
||||
t.Errorf("process %d error: %s", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRun_Error(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ctx := context.Background()
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p, err := open(t, ctx, addr)
|
||||
if err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, p.close(t))
|
||||
})
|
||||
|
||||
err = p.Run(ctx, bytes.NewReader([]byte("SELECT *\nFROM foo")))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos int
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
|
||||
},
|
||||
{
|
||||
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
|
||||
},
|
||||
{
|
||||
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
|
||||
},
|
||||
{
|
||||
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
|
||||
},
|
||||
{
|
||||
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
|
||||
},
|
||||
{
|
||||
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
|
||||
},
|
||||
{
|
||||
17, 2, 8, "SELECT *\nFROM foo", true, // last character
|
||||
},
|
||||
{
|
||||
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
run := func(crlf bool, nonASCII bool) {
|
||||
var name string
|
||||
if crlf {
|
||||
name = "crlf"
|
||||
} else {
|
||||
name = "lf"
|
||||
}
|
||||
if nonASCII {
|
||||
name += "-nonascii"
|
||||
} else {
|
||||
name += "-ascii"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
input := tc.input
|
||||
if crlf {
|
||||
input = strings.Replace(input, "\n", "\r\n", -1)
|
||||
}
|
||||
if nonASCII {
|
||||
input = strings.Replace(input, "FROM", "FRÖM", -1)
|
||||
}
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
|
||||
|
||||
if tc.wantOk {
|
||||
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
|
||||
}
|
||||
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
run(false, false)
|
||||
run(true, false)
|
||||
run(false, true)
|
||||
run(true, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,169 @@
|
||||
// The MIT License (MIT)
|
||||
//
|
||||
// Original Work
|
||||
// Copyright (c) 2016 Matthias Kadenbach
|
||||
// https://github.com/mattes/migrate
|
||||
//
|
||||
// Modified Work
|
||||
// Copyright (c) 2018 Dale Hui
|
||||
// https://github.com/golang-migrate/migrate
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
// THE SOFTWARE.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// test runs tests against database implementations.
|
||||
func test(t *testing.T, d *Postgres, migration []byte) {
|
||||
if migration == nil {
|
||||
t.Fatal("test must provide migration reader")
|
||||
}
|
||||
|
||||
testNilVersion(t, d) // test first
|
||||
testLockAndUnlock(t, d)
|
||||
testRun(t, d, bytes.NewReader(migration))
|
||||
testSetVersion(t, d) // also tests Version()
|
||||
// drop breaks the driver, so test it last.
|
||||
testDrop(t, d)
|
||||
}
|
||||
|
||||
func testNilVersion(t *testing.T, d *Postgres) {
|
||||
ctx := context.Background()
|
||||
v, _, err := d.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if v != database.NilVersion {
|
||||
t.Fatalf("Version: expected Version to be NilVersion (-1), got %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func testLockAndUnlock(t *testing.T, d *Postgres) {
|
||||
ctx := context.Background()
|
||||
|
||||
ctx, _ = context.WithTimeout(ctx, 15*time.Second)
|
||||
|
||||
// locking twice is ok, no error
|
||||
if err := d.Lock(ctx); err != nil {
|
||||
t.Fatalf("got error, expected none: %v", err)
|
||||
}
|
||||
if err := d.Lock(ctx); err != nil {
|
||||
t.Fatalf("got error, expected none: %v", err)
|
||||
}
|
||||
|
||||
// Unlock
|
||||
if err := d.Unlock(ctx); err != nil {
|
||||
t.Fatalf("error unlocking: %v", err)
|
||||
}
|
||||
|
||||
// try to Lock
|
||||
if err := d.Lock(ctx); err != nil {
|
||||
t.Fatalf("got error, expected none: %v", err)
|
||||
}
|
||||
if err := d.Unlock(ctx); err != nil {
|
||||
t.Fatalf("got error, expected none: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testRun(t *testing.T, d *Postgres, migration io.Reader) {
|
||||
ctx := context.Background()
|
||||
if migration == nil {
|
||||
t.Fatal("migration can't be nil")
|
||||
}
|
||||
|
||||
if err := d.Run(ctx, migration); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testDrop(t *testing.T, d *Postgres) {
|
||||
ctx := context.Background()
|
||||
if err := d.drop(ctx); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testSetVersion(t *testing.T, d *Postgres) {
|
||||
ctx := context.Background()
|
||||
// nolint:maligned
|
||||
testCases := []struct {
|
||||
name string
|
||||
version int
|
||||
dirty bool
|
||||
expectedErr error
|
||||
expectedReadErr error
|
||||
expectedVersion int
|
||||
expectedDirty bool
|
||||
}{
|
||||
{name: "set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
|
||||
{name: "re-set 1 dirty", version: 1, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: 1, expectedDirty: true},
|
||||
{name: "set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
|
||||
{name: "re-set 2 clean", version: 2, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: 2, expectedDirty: false},
|
||||
{name: "last migration dirty", version: database.NilVersion, dirty: true, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: true},
|
||||
{name: "last migration clean", version: database.NilVersion, dirty: false, expectedErr: nil, expectedReadErr: nil, expectedVersion: database.NilVersion, expectedDirty: false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := d.SetVersion(ctx, tc.version, tc.dirty)
|
||||
if err != tc.expectedErr {
|
||||
t.Fatal("Got unexpected error:", err, "!=", tc.expectedErr)
|
||||
}
|
||||
v, dirty, readErr := d.Version(ctx)
|
||||
if readErr != tc.expectedReadErr {
|
||||
t.Fatal("Got unexpected error:", readErr, "!=", tc.expectedReadErr)
|
||||
}
|
||||
if v != tc.expectedVersion {
|
||||
t.Error("Got unexpected Version:", v, "!=", tc.expectedVersion)
|
||||
}
|
||||
if dirty != tc.expectedDirty {
|
||||
t.Error("Got unexpected dirty value:", dirty, "!=", tc.dirty)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func open(t *testing.T, ctx context.Context, u string) (*Postgres, error) {
|
||||
t.Helper()
|
||||
db, err := sql.Open("postgres", u)
|
||||
require.NoError(t, err)
|
||||
|
||||
px, err := New(ctx, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) close(t *testing.T) error {
|
||||
t.Helper()
|
||||
require.NoError(t, p.conn.Close())
|
||||
require.NoError(t, p.db.Close())
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,40 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
)
|
||||
|
||||
// InitStore executes the migrations needed to initialize the store. It
|
||||
// returns true if migrations actually ran; false if the database is already current.
|
||||
func InitStore(ctx context.Context, dialect string, url string) (bool, error) {
|
||||
const op = "schema.InitStore"
|
||||
d, err := sql.Open(dialect, url)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
sMan, err := NewManager(ctx, dialect, d)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, op)
|
||||
}
|
||||
|
||||
st, err := sMan.CurrentState(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, op)
|
||||
}
|
||||
if st.Dirty {
|
||||
return false, errors.New(errors.MigrationIntegrity, op, "db marked dirty")
|
||||
}
|
||||
|
||||
if st.InitializationStarted && st.DatabaseSchemaVersion == st.BinarySchemaVersion {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := sMan.RollForward(ctx); err != nil {
|
||||
return false, errors.Wrap(err, op)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@ -0,0 +1,65 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/docker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitStore(t *testing.T) {
|
||||
dialect := "postgres"
|
||||
ctx := context.Background()
|
||||
|
||||
c, u, _, err := docker.StartDbInDocker(dialect)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
|
||||
// Set the possible migration state to be only part of the full migration
|
||||
oState := migrationStates[dialect]
|
||||
nState := createPartialMigrationState(oState, 8)
|
||||
migrationStates[dialect] = nState
|
||||
|
||||
ran, err := InitStore(ctx, dialect, u)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ran)
|
||||
ran, err = InitStore(ctx, dialect, u)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, ran)
|
||||
|
||||
// Reset the possible migration state to contain everything
|
||||
migrationStates[dialect] = oState
|
||||
|
||||
ran, err = InitStore(ctx, dialect, u)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ran)
|
||||
ran, err = InitStore(ctx, dialect, u)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, ran)
|
||||
}
|
||||
|
||||
func TestInitStore_Dirty(t *testing.T) {
|
||||
dialect := "postgres"
|
||||
ctx := context.Background()
|
||||
|
||||
c, u, _, err := docker.StartDbInDocker(dialect)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, c())
|
||||
})
|
||||
|
||||
// Mark the db as dirty indicating a previously run failed migration
|
||||
db, err := sql.Open(dialect, u)
|
||||
require.NoError(t, err)
|
||||
m, err := NewManager(ctx, dialect, db)
|
||||
m.driver.SetVersion(ctx, -1, true)
|
||||
|
||||
b, err := InitStore(ctx, dialect, u)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, b)
|
||||
}
|
||||
@ -0,0 +1,54 @@
|
||||
package schema
|
||||
|
||||
const nilVersion = -1
|
||||
|
||||
// migrationState is meant to be populated by the generated migration code and
|
||||
// contains the internal representation of a schema in the current binary.
|
||||
type migrationState struct {
|
||||
// devMigration is true if the database schema that would be applied by
|
||||
// InitStore would be from files in the /dev directory which indicates it would
|
||||
// not be safe to run in a non dev environment.
|
||||
devMigration bool
|
||||
|
||||
// binarySchemaVersion provides the database schema version supported by
|
||||
// this binary.
|
||||
binarySchemaVersion int
|
||||
|
||||
upMigrations map[int][]byte
|
||||
downMigrations map[int][]byte
|
||||
}
|
||||
|
||||
// migrationStates is populated by the generated migration code with the key being the dialect.
|
||||
var migrationStates = make(map[string]migrationState)
|
||||
|
||||
func getUpMigration(dialect string) map[int][]byte {
|
||||
ms, ok := migrationStates[dialect]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ms.upMigrations
|
||||
}
|
||||
|
||||
func getDownMigration(dialect string) map[int][]byte {
|
||||
ms, ok := migrationStates[dialect]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return ms.downMigrations
|
||||
}
|
||||
|
||||
// DevMigration returns true iff the provided dialect has changes which are still in development.
|
||||
func DevMigration(dialect string) bool {
|
||||
ms, ok := migrationStates[dialect]
|
||||
return ok && ms.devMigration
|
||||
}
|
||||
|
||||
// BinarySchemaVersion provides the schema version that this binary supports for the provided dialect.
|
||||
// If the binary doesn't support this dialect -1 is returned.
|
||||
func BinarySchemaVersion(dialect string) int {
|
||||
ms, ok := migrationStates[dialect]
|
||||
if !ok {
|
||||
return nilVersion
|
||||
}
|
||||
return ms.binarySchemaVersion
|
||||
}
|
||||
@ -0,0 +1,23 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBinarySchemaVersion(t *testing.T) {
|
||||
dialect := "test_binaryschemaversion"
|
||||
migrationStates[dialect] = migrationState{binarySchemaVersion: 3}
|
||||
assert.Equal(t, 3, BinarySchemaVersion(dialect))
|
||||
assert.Equal(t, nilVersion, BinarySchemaVersion("unknown_dialect"))
|
||||
}
|
||||
|
||||
func TestDevMigration(t *testing.T) {
|
||||
dialect := "test_devmigrations"
|
||||
migrationStates[dialect] = migrationState{devMigration: true}
|
||||
assert.True(t, DevMigration(dialect))
|
||||
migrationStates[dialect] = migrationState{devMigration: false}
|
||||
assert.False(t, DevMigration(dialect))
|
||||
assert.False(t, DevMigration("unknown_dialect"))
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
)
|
||||
|
||||
// statementProvider provides the migration statements in order.
|
||||
// Next should be called prior to calling Version() or ReadUp() or sentinel
|
||||
// values (-1 and nil) will be returned.
|
||||
type statementProvider struct {
|
||||
pos int
|
||||
versions []int
|
||||
up, down map[int][]byte
|
||||
}
|
||||
|
||||
func newStatementProvider(dialect string, curVer int) (*statementProvider, error) {
|
||||
op := errors.Op("schema.newStatementProvider")
|
||||
qp := statementProvider{pos: -1}
|
||||
qp.up, qp.down = getUpMigration(dialect), getDownMigration(dialect)
|
||||
if len(qp.up) != len(qp.down) {
|
||||
return nil, errors.New(errors.MigrationIntegrity, op, fmt.Sprintf("Mismatch up/down size: up %d vs. down %d", len(qp.up), len(qp.down)))
|
||||
}
|
||||
for k := range qp.up {
|
||||
if _, ok := qp.down[k]; !ok {
|
||||
return nil, errors.New(errors.MigrationIntegrity, op, fmt.Sprintf("Up key %d doesn't exist in down %v", k, qp.down))
|
||||
}
|
||||
qp.versions = append(qp.versions, k)
|
||||
}
|
||||
sort.Ints(qp.versions)
|
||||
|
||||
for len(qp.versions) > 0 && qp.versions[0] <= curVer {
|
||||
qp.versions = qp.versions[1:]
|
||||
}
|
||||
|
||||
return &qp, nil
|
||||
}
|
||||
|
||||
func (q *statementProvider) Next() bool {
|
||||
q.pos++
|
||||
return len(q.versions) > q.pos
|
||||
}
|
||||
|
||||
func (q *statementProvider) Version() int {
|
||||
if q.pos < 0 || q.pos >= len(q.versions) {
|
||||
return -1
|
||||
}
|
||||
return q.versions[q.pos]
|
||||
}
|
||||
|
||||
// ReadUp reads the current up migration
|
||||
func (q *statementProvider) ReadUp() []byte {
|
||||
if q.pos < 0 || q.pos >= len(q.versions) {
|
||||
return nil
|
||||
}
|
||||
return q.up[q.versions[q.pos]]
|
||||
}
|
||||
@ -0,0 +1,87 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStatementProvider(t *testing.T) {
|
||||
testDialect := "test"
|
||||
migrationStates[testDialect] = migrationState{
|
||||
binarySchemaVersion: 5,
|
||||
upMigrations: map[int][]byte{
|
||||
1: []byte("one"),
|
||||
2: []byte("two"),
|
||||
3: []byte("three"),
|
||||
},
|
||||
downMigrations: map[int][]byte{
|
||||
1: []byte("down one"),
|
||||
2: []byte("down two"),
|
||||
3: []byte("down three"),
|
||||
},
|
||||
}
|
||||
|
||||
st, err := newStatementProvider(testDialect, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, -1, st.Version())
|
||||
assert.Equal(t, []byte(nil), st.ReadUp())
|
||||
|
||||
assert.True(t, st.Next())
|
||||
assert.Equal(t, 2, st.Version())
|
||||
assert.Equal(t, []byte("two"), st.ReadUp())
|
||||
|
||||
assert.True(t, st.Next())
|
||||
assert.Equal(t, 3, st.Version())
|
||||
assert.Equal(t, []byte("three"), st.ReadUp())
|
||||
|
||||
assert.False(t, st.Next())
|
||||
assert.Equal(t, -1, st.Version())
|
||||
assert.Equal(t, []byte(nil), st.ReadUp())
|
||||
|
||||
assert.False(t, st.Next())
|
||||
assert.Equal(t, -1, st.Version())
|
||||
assert.Equal(t, []byte(nil), st.ReadUp())
|
||||
|
||||
st, err = newStatementProvider("unknown_dialect", nilVersion)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, st.Next())
|
||||
}
|
||||
|
||||
func TestStatementProvider_error(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in migrationState
|
||||
}{
|
||||
{
|
||||
name: "mismatchLength",
|
||||
in: migrationState{
|
||||
binarySchemaVersion: 5,
|
||||
upMigrations: map[int][]byte{
|
||||
1: []byte("one"),
|
||||
},
|
||||
downMigrations: map[int][]byte{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mismatchVersions",
|
||||
in: migrationState{
|
||||
binarySchemaVersion: 5,
|
||||
upMigrations: map[int][]byte{
|
||||
1: []byte("one"),
|
||||
},
|
||||
downMigrations: map[int][]byte{
|
||||
2: []byte("two"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
migrationStates[tc.name] = tc.in
|
||||
defer delete(migrationStates, tc.name)
|
||||
_, err := newStatementProvider(tc.name, -1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue