Add httpfs DB migrations and generate the input (#35)

Jeff Mitchell authored 6 years ago; committed by GitHub
parent bd5ddea6e4
commit 7f2aa5ca8f

@ -5,16 +5,19 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST))
TMP_DIR := $(shell mktemp -d)
REPO_PATH := github.com/hashicorp/watchtower
export APIGEN_BASEPATH := $(shell pwd)
export GEN_BASEPATH := $(shell pwd)
bootstrap:
go generate -tags tools tools/tools.go
gen: proto api
gen: proto api migrations
api:
$(MAKE) --environment-overrides -C api/internal/genapi api
migrations:
$(MAKE) --environment-overrides -C internal/db/migrations/genmigrations migrations
### oplog requires protoc-gen-go v1.20.0 or later
# GO111MODULE=on go get -u github.com/golang/protobuf/protoc-gen-go@v1.40
proto: protolint protobuild cleanup
@ -49,6 +52,6 @@ cleanup:
@rm -R ${TMP_DIR}
.PHONY: api bootstrap cleanup gen proto
.PHONY: api bootstrap cleanup gen migrations proto
.NOTPARALLEL:

@ -4,6 +4,6 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST))
api:
go run -tags genapi .
goimports -w ${APIGEN_BASEPATH}/api
goimports -w ${GEN_BASEPATH}/api
.PHONY: api

@ -37,7 +37,7 @@ var createFuncs = map[string][]*createInfo{
func writeCreateFuncs() {
for outPkg, funcs := range createFuncs {
outFile := os.Getenv("APIGEN_BASEPATH") + fmt.Sprintf("/api/%s/create.gen.go", outPkg)
outFile := os.Getenv("GEN_BASEPATH") + fmt.Sprintf("/api/%s/create.gen.go", outPkg)
outBuf := bytes.NewBuffer([]byte(fmt.Sprintf(
`// Code generated by "make api"; DO NOT EDIT.
package %s
@ -92,4 +92,4 @@ func (s {{ .BaseType }}) Create{{ .TargetType }}(ctx context.Context, {{ .LowerT
return target, apiErr, nil
}
`))
`))

@ -18,9 +18,9 @@ type structInfo struct {
var inputStructs = []*structInfo{
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go",
"Error",
os.Getenv("APIGEN_BASEPATH") + "/api/error.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/error.gen.go",
"Error",
"api",
"",
@ -29,9 +29,9 @@ var inputStructs = []*structInfo{
templateTypeResource,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go",
"ErrorDetails",
os.Getenv("APIGEN_BASEPATH") + "/api/error_details.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/error_details.gen.go",
"ErrorDetails",
"api",
"",
@ -40,9 +40,9 @@ var inputStructs = []*structInfo{
templateTypeResource,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host.pb.go",
"Host",
os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/hosts/host.gen.go",
"Host",
"hosts",
"",
@ -51,9 +51,9 @@ var inputStructs = []*structInfo{
templateTypeResource,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_set.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_set.pb.go",
"HostSet",
os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host_set.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/hosts/host_set.gen.go",
"HostSet",
"hosts",
"",
@ -62,9 +62,9 @@ var inputStructs = []*structInfo{
templateTypeResource,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
"HostCatalog",
os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host_catalog.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/hosts/host_catalog.gen.go",
"HostCatalog",
"hosts",
"",
@ -73,9 +73,9 @@ var inputStructs = []*structInfo{
templateTypeResource,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
"StaticHostCatalogDetails",
os.Getenv("APIGEN_BASEPATH") + "/api/hosts/static_host_catalog.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/hosts/static_host_catalog.gen.go",
"StaticHostCatalogDetails",
"hosts",
"",
@ -84,9 +84,9 @@ var inputStructs = []*structInfo{
templateTypeDetail,
},
{
os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
"AwsEc2HostCatalogDetails",
os.Getenv("APIGEN_BASEPATH") + "/api/hosts/awsec2_host_catalog.gen.go",
os.Getenv("GEN_BASEPATH") + "/api/hosts/awsec2_host_catalog.gen.go",
"AwsEc2HostCatalogDetails",
"hosts",
"",

@ -28,7 +28,7 @@ require (
github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/hcl v1.0.0
github.com/hashicorp/vault v1.2.1-0.20200310200732-a321b1217d5a
github.com/hashicorp/vault/api v1.0.5-0.20200310200732-a321b1217d5a
github.com/hashicorp/vault/api v1.0.5-0.20200310200732-a321b1217d5a // indirect
github.com/hashicorp/vault/sdk v0.1.14-0.20200310200732-a321b1217d5a
github.com/jackc/pgx/v4 v4.6.0
github.com/jinzhu/gorm v1.9.12
@ -39,6 +39,7 @@ require (
github.com/mitchellh/cli v1.0.0
github.com/mitchellh/go-homedir v1.1.0
github.com/mitchellh/mapstructure v1.2.2
github.com/ory/dockertest v3.3.5+incompatible
github.com/ory/dockertest/v3 v3.5.4
github.com/pelletier/go-toml v1.7.0 // indirect
github.com/pires/go-proxyproto v0.0.0-20200213100827-833e5d06d8f0

@ -185,6 +185,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80=
github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381 h1:rdRS5BT13Iae9ssvcslol66gfOOXjaLYwqerEn/cl9s=
github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381/go.mod h1:e5+USP2j8Le2M0Jo3qKPFnNhuo1wueU4nWHCXBOfQ14=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c h1:2zRrJWIt/f9c9HhNHAgrRgq0San5gRRUJTBXLkchal0=
github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c/go.mod h1:XGLbWH/ujMcbPbhZq52Nv6UrCghb1yGn//133kEsvDk=
@ -722,6 +723,7 @@ github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o=
github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
@ -1063,6 +1065,7 @@ github.com/shirou/gopsutil v2.19.9+incompatible h1:IrPVlK4nfwW10DF7pW+7YJKws9Nkg
github.com/shirou/gopsutil v2.19.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U=
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=

@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"database/sql"
"errors"
"fmt"
"io"
@ -18,7 +17,6 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/internalshared/gatedwriter"
"github.com/hashicorp/vault/internalshared/reloadutil"
@ -26,9 +24,9 @@ import (
"github.com/hashicorp/vault/sdk/helper/mlock"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/watchtower/globals"
"github.com/hashicorp/watchtower/internal/db"
"github.com/hashicorp/watchtower/version"
"github.com/mitchellh/cli"
"github.com/ory/dockertest/v3"
"golang.org/x/net/http/httpproxy"
"google.golang.org/grpc/grpclog"
@ -60,12 +58,8 @@ type Server struct {
Listeners []*ServerListener
DevDatabasePassword string
DevDatabaseName string
DevDatabasePort string
dockertestPool *dockertest.Pool
dockertestResource *dockertest.Resource
DevDatabaseUrl string
DevDatabaseCleanupFunc func() error
}
func NewServer() *Server {
@ -372,54 +366,23 @@ func (b *Server) RunShutdownFuncs(ui cli.Ui) {
}
func (b *Server) CreateDevDatabase() error {
pool, err := dockertest.NewPool("")
if err != nil {
return fmt.Errorf("could not connect to docker: %w", err)
}
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"})
c, url, err := db.InitDbInDocker("postgres")
if err != nil {
return fmt.Errorf("could not start resource: %w", err)
}
if err := pool.Retry(func() error {
db, err := sql.Open("postgres", fmt.Sprintf("postgres://postgres:secret@localhost:%s/watchtower?sslmode=disable", resource.GetPort("5432/tcp")))
if err != nil {
return fmt.Errorf("error opening postgres dev container: %w", err)
}
var mErr *multierror.Error
if err := db.Ping(); err != nil {
mErr = multierror.Append(fmt.Errorf("error pinging dev database container: %w", err))
}
if err := db.Close(); err != nil {
mErr = multierror.Append(fmt.Errorf("error closing dev database container: %w", err))
}
return mErr.ErrorOrNil()
}); err != nil {
return fmt.Errorf("could not connect to docker: %w", err)
c()
return fmt.Errorf("unable to start dev database: %w", err)
}
b.dockertestPool = pool
b.dockertestResource = resource
b.DevDatabaseName = "watchtower"
b.DevDatabasePassword = "secret"
b.DevDatabasePort = resource.GetPort("5432/tcp")
b.DevDatabaseCleanupFunc = c
b.DevDatabaseUrl = url
b.InfoKeys = append(b.InfoKeys, "dev database name", "dev database password", "dev database port")
b.Info["dev database name"] = "watchtower"
b.Info["dev database password"] = "secret"
b.Info["dev database port"] = b.DevDatabasePort
b.InfoKeys = append(b.InfoKeys, "dev database url")
b.Info["dev database url"] = b.DevDatabaseUrl
return nil
}
func (b *Server) DestroyDevDatabase() error {
if b.dockertestPool == nil {
return nil
if b.DevDatabaseCleanupFunc != nil {
return b.DevDatabaseCleanupFunc()
}
if b.dockertestResource == nil {
return errors.New("found a pool for dev database container but no resource")
}
return b.dockertestPool.Purge(b.dockertestResource)
return errors.New("no dev database cleanup function found")
}
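
With this change the dev server no longer drives dockertest itself; it only records the URL returned by db.InitDbInDocker and a cleanup closure for DestroyDevDatabase to call. A minimal usage sketch follows; the base import path is an assumption (it is not shown in this diff) and should be whatever package defines Server:

package main

import (
    "log"

    // Assumed import path for the Server type above; adjust to the real package.
    base "github.com/hashicorp/watchtower/internal/cmd/base"
)

func main() {
    srv := base.NewServer()
    if err := srv.CreateDevDatabase(); err != nil {
        log.Fatalf("dev database: %v", err)
    }
    defer func() {
        if err := srv.DestroyDevDatabase(); err != nil {
            log.Printf("dev database cleanup: %v", err)
        }
    }()
    // The migrated dev database is reachable at srv.DevDatabaseUrl.
    log.Println("dev database url:", srv.DevDatabaseUrl)
}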

@ -8,6 +8,7 @@ import (
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/hashicorp/watchtower/internal/db/migrations"
"github.com/jinzhu/gorm"
"github.com/ory/dockertest"
)
@ -55,28 +56,57 @@ func Migrate(connectionUrl string, migrationsDirectory string) error {
return nil
}
// InitDbInDocker initializes the data store, either inside a docker container or against an existing database when PG_URL is set
func InitDbInDocker(dialect string) (cleanup func() error, retURL string, err error) {
switch dialect {
case "postgres":
if os.Getenv("PG_URL") != "" {
if err := InitStore(dialect, func() error { return nil }, os.Getenv("PG_URL")); err != nil {
return func() error { return nil }, os.Getenv("PG_URL"), fmt.Errorf("error initializing store: %w", err)
}
return func() error { return nil }, os.Getenv("PG_URL"), nil
}
}
c, url, err := StartDbInDocker(dialect)
if err != nil {
return func() error { return nil }, "", fmt.Errorf("could not start docker: %w", err)
}
if err := InitStore(dialect, c, url); err != nil {
return func() error { return nil }, "", fmt.Errorf("error initializing store: %w", err)
}
return c, url, nil
}
// StartDbInDocker starts a database of the given dialect inside a docker container
func StartDbInDocker() (cleanup func() error, retURL string, err error) {
func StartDbInDocker(dialect string) (cleanup func() error, retURL string, err error) {
pool, err := dockertest.NewPool("")
if err != nil {
return func() error { return nil }, "", fmt.Errorf("could not connect to docker: %w", err)
}
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"})
var resource *dockertest.Resource
var url string
switch dialect {
case "postgres":
resource, err = pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"})
url = "postgres://postgres:secret@localhost:%s?sslmode=disable"
default:
panic(fmt.Sprintf("unknown dialect %q", dialect))
}
if err != nil {
return func() error { return nil }, "", fmt.Errorf("could not start resource: %w", err)
}
c := func() error {
cleanup = func() error {
return cleanupDockerResource(pool, resource)
}
url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp"))
url = fmt.Sprintf(url, resource.GetPort("5432/tcp"))
if err := pool.Retry(func() error {
db, err := sql.Open("postgres", url)
db, err := sql.Open(dialect, url)
if err != nil {
return fmt.Errorf("error opening postgres dev container: %w", err)
return fmt.Errorf("error opening %s dev container: %w", dialect, err)
}
if err := db.Ping(); err != nil {
@ -87,7 +117,29 @@ func StartDbInDocker() (cleanup func() error, retURL string, err error) {
}); err != nil {
return func() error { return nil }, "", fmt.Errorf("could not connect to docker: %w", err)
}
return c, url, nil
return cleanup, url, nil
}
// InitStore will execute the migrations needed to initialize the store
func InitStore(dialect string, cleanup func() error, url string) error {
// run migrations
source, err := migrations.NewMigrationSource(dialect)
if err != nil {
cleanup()
return fmt.Errorf("error creating migration driver: %w", err)
}
m, err := migrate.NewWithSourceInstance("httpfs", source, url)
if err != nil {
cleanup()
return fmt.Errorf("error creating migrations: %w", err)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
cleanup()
return fmt.Errorf("error running migrations: %w", err)
}
return nil
}
// cleanupDockerResource will clean up the dockertest resources (postgres)
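
Taken together, InitDbInDocker either honors PG_URL or starts a container via StartDbInDocker, then runs the embedded migrations through InitStore before returning the URL and cleanup function. A hedged sketch of calling it directly and opening the result with gorm, mirroring what the test helpers below do:

package main

import (
    "log"

    "github.com/hashicorp/watchtower/internal/db"
    "github.com/jinzhu/gorm"
    // Registers gorm's postgres dialect for gorm.Open("postgres", ...).
    _ "github.com/jinzhu/gorm/dialects/postgres"
)

func main() {
    cleanup, url, err := db.InitDbInDocker("postgres")
    if err != nil {
        log.Fatal(err)
    }
    defer cleanup()

    conn, err := gorm.Open("postgres", url)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()
    // conn now points at a fully migrated database
    // (dockertest container, or PG_URL if that was set).
}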

@ -5,7 +5,7 @@ import (
)
func TestOpen(t *testing.T) {
cleanup, url, err := StartDbInDocker()
cleanup, url, err := StartDbInDocker("postgres")
if err != nil {
t.Fatal(err)
}
@ -56,7 +56,7 @@ func TestOpen(t *testing.T) {
}
func TestMigrate(t *testing.T) {
cleanup, url, err := StartDbInDocker()
cleanup, url, err := StartDbInDocker("postgres")
if err != nil {
t.Fatal(err)
}

@ -21,7 +21,7 @@ returning id;
`
)
cleanup, conn := TestSetup(t, "migrations/postgres")
cleanup, conn := TestSetup(t, "postgres")
defer cleanup()
defer conn.Close()

@ -0,0 +1,121 @@
package migrations
import (
"bytes"
"fmt"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/golang-migrate/migrate/v4/source"
"github.com/golang-migrate/migrate/v4/source/httpfs"
)
// migrationDriver provides the http.FileSystem needed by the httpfs source
// driver; httpfs's PartialDriver supplies the rest of the source.Driver interface
type migrationDriver struct {
dialect string
}
// Open returns the named migration as an in-memory "file"
func (m *migrationDriver) Open(name string) (http.File, error) {
var ff *fakeFile
switch m.dialect {
case "postgres":
ff = postgresMigrations[name]
}
if ff == nil {
return nil, os.ErrNotExist
}
ff.name = strings.TrimPrefix(name, "migrations/")
ff.reader = bytes.NewReader(ff.bytes)
ff.dialect = m.dialect
return ff, nil
}
// NewMigrationSource creates a source.Driver using httpfs with the given dialect
func NewMigrationSource(dialect string) (source.Driver, error) {
switch dialect {
case "postgres":
default:
return nil, fmt.Errorf("unknown migrations dialect %s", dialect)
}
return httpfs.New(&migrationDriver{dialect}, "migrations")
}
// fakeFile is used to satisfy the http.File interface
type fakeFile struct {
name string
bytes []byte
reader *bytes.Reader
dialect string
}
func (f *fakeFile) Read(p []byte) (n int, err error) {
return f.reader.Read(p)
}
func (f *fakeFile) Seek(offset int64, whence int) (int64, error) {
return f.reader.Seek(offset, whence)
}
func (f *fakeFile) Close() error { return nil }
// Readdir returns os.FileInfo values in sorted order, eliding the top-level
// migrations "dir" entry
func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) {
// Get the right map
var migrationsMap map[string]*fakeFile
switch f.dialect {
case "postgres":
migrationsMap = postgresMigrations
default:
return nil, fmt.Errorf("unknown database dialect %s", f.dialect)
}
// Sort the keys. May not be necessary but feels nice.
keys := make([]string, 0, len(migrationsMap))
for k := range migrationsMap {
keys = append(keys, k)
}
sort.Strings(keys)
// Create the slice of fileinfo objects to return
ret := make([]os.FileInfo, 0, len(migrationsMap))
for _, v := range keys {
// We need "migrations" in the map for the initial Open call but we
// should not return it as part of the "directory"'s "files".
if v == "migrations" {
continue
}
stat, err := migrationsMap[v].Stat()
if err != nil {
return nil, err
}
ret = append(ret, stat)
}
return ret, nil
}
// Stat returns a new fakeFileInfo object with the necessary bits
func (f *fakeFile) Stat() (os.FileInfo, error) {
return &fakeFileInfo{
name: f.name,
size: int64(len(f.bytes)),
}, nil
}
// fakeFileInfo satisfies os.FileInfo but represents our fake "files"
type fakeFileInfo struct {
name string
size int64
}
func (f *fakeFileInfo) Name() string { return f.name }
func (f *fakeFileInfo) Size() int64 { return f.size }
func (f *fakeFileInfo) Mode() os.FileMode { return os.ModePerm }
func (f *fakeFileInfo) ModTime() time.Time { return time.Now() }
func (f *fakeFileInfo) IsDir() bool { return false }
func (f *fakeFileInfo) Sys() interface{} { return nil }
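
This shim exists so golang-migrate can read the generated migration map through its httpfs source without touching the filesystem; InitStore in internal/db is the in-tree consumer. A hedged sketch of driving the source directly (the connection URL here is only a placeholder):

package main

import (
    "log"

    "github.com/golang-migrate/migrate/v4"
    // Registers the "postgres" database driver with golang-migrate.
    _ "github.com/golang-migrate/migrate/v4/database/postgres"
    "github.com/hashicorp/watchtower/internal/db/migrations"
)

func main() {
    source, err := migrations.NewMigrationSource("postgres")
    if err != nil {
        log.Fatal(err)
    }
    // Placeholder URL; point this at a reachable database.
    url := "postgres://postgres:secret@localhost:5432/watchtower?sslmode=disable"
    m, err := migrate.NewWithSourceInstance("httpfs", source, url)
    if err != nil {
        log.Fatal(err)
    }
    if err := m.Up(); err != nil && err != migrate.ErrNoChange {
        log.Fatal(err)
    }
}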

@ -0,0 +1,8 @@
# Be sure to place this BEFORE `include` directives, if any.
THIS_FILE := $(lastword $(MAKEFILE_LIST))
migrations:
go run -tags genmigrations .
goimports -w ${GEN_BASEPATH}/internal/db/migrations
.PHONY: migrations

@ -0,0 +1,88 @@
// +build genmigrations
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"sort"
"strings"
"text/template"
)
// generate looks for migration sql in a directory for the given dialect and
// applies the templates below to the contents of the files, building up a
// migrations map for the dialect
func generate(dialect string) {
baseDir := os.Getenv("GEN_BASEPATH") + fmt.Sprint("/internal/db/migrations")
dir, err := os.Open(fmt.Sprintf("%s/%s", baseDir, dialect))
if err != nil {
fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
names, err := dir.Readdirnames(0)
if err != nil {
fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err)
os.Exit(1)
}
outBuf := bytes.NewBuffer(nil)
valuesBuf := bytes.NewBuffer(nil)
sort.Strings(names)
for _, name := range names {
if !strings.HasSuffix(name, ".sql") {
continue
}
contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s", baseDir, dialect, name))
if err != nil {
fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err)
os.Exit(1)
}
migrationsValueTemplate.Execute(valuesBuf, struct {
Name string
Contents string
}{
Name: name,
Contents: string(contents),
})
}
migrationsTemplate.Execute(outBuf, struct {
Type string
Values string
}{
Type: dialect,
Values: valuesBuf.String(),
})
outFile := fmt.Sprintf("%s/%s.gen.go", baseDir, dialect)
if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0644); err != nil {
fmt.Printf("error writing file %q: %v\n", outFile, err)
os.Exit(1)
}
}
var migrationsTemplate = template.Must(template.New("").Parse(
`// Code generated by "make migrations"; DO NOT EDIT.
package migrations
import (
"bytes"
)
var {{ .Type }}Migrations = map[string]*fakeFile{
"migrations": {
name: "migrations",
},
{{ .Values }}
}
`))
var migrationsValueTemplate = template.Must(template.New("").Parse(
`"migrations/{{ .Name }}": {
name: "{{ .Name }}",
bytes: []byte(` + "`\n{{ .Contents }}\n`" + `),
},
`))

@ -0,0 +1,7 @@
// +build genmigrations
package main
func main() {
generate("postgres")
}

@ -0,0 +1,225 @@
// Code generated by "make migrations"; DO NOT EDIT.
package migrations
var postgresMigrations = map[string]*fakeFile{
"migrations": {
name: "migrations",
},
"migrations/01_oplog.up.sql": {
name: "01_oplog.up.sql",
bytes: []byte(`
CREATE TABLE if not exists oplog_entry (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
version text NOT NULL,
aggregate_name text NOT NULL,
"data" bytea NOT NULL
);
CREATE TABLE if not exists oplog_ticket (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
"name" text NOT NULL UNIQUE,
"version" bigint NOT NULL
);
CREATE TABLE if not exists oplog_metadata (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
entry_id bigint NOT NULL REFERENCES oplog_entry(id) ON DELETE CASCADE ON UPDATE CASCADE,
"key" text NOT NULL,
value text NULL
);
create index if not exists idx_oplog_metatadata_key on oplog_metadata(key);
create index if not exists idx_oplog_metatadata_value on oplog_metadata(value);
INSERT INTO oplog_ticket (name, version)
values
('default', 1),
('iam_scope', 1),
('iam_user', 1),
('iam_auth_method', 1),
('iam_group', 1),
('iam_group_member_user', 1),
('iam_role', 1),
('iam_role_grant', 1),
('iam_role_group', 1),
('iam_role_user', 1),
('db_test_user', 1),
('db_test_car', 1),
('db_test_rental', 1);
`),
},
"migrations/02_domain_types.down.sql": {
name: "02_domain_types.down.sql",
bytes: []byte(`
begin;
drop domain wt_public_id;
commit;
`),
},
"migrations/02_domain_types.up.sql": {
name: "02_domain_types.up.sql",
bytes: []byte(`
begin;
create domain wt_public_id as text
check(
length(trim(value)) > 10
);
comment on domain wt_public_id is
'Random ID generated with github.com/hashicorp/vault/sdk/helper/base62';
commit;
`),
},
"migrations/03_db.down.sql": {
name: "03_db.down.sql",
bytes: []byte(`
drop table if exists db_test_user;
drop table if exists db_test_car;
drop table if exists db_test_rental;
`),
},
"migrations/03_db.up.sql": {
name: "03_db.up.sql",
bytes: []byte(`
-- create test tables used in the unit tests for the internal/db package
-- these tables (db_test_user, db_test_car, db_test_rental) are not part
-- of the Watchtower domain model... they are simply used for testing the internal/db package
CREATE TABLE if not exists db_test_user (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
public_id text NOT NULL UNIQUE,
name text UNIQUE,
phone_number text,
email text
);
CREATE TABLE if not exists db_test_car (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
public_id text NOT NULL UNIQUE,
name text UNIQUE,
model text,
mpg smallint
);
CREATE TABLE if not exists db_test_rental (
id bigint generated always as identity primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
public_id text NOT NULL UNIQUE,
name text UNIQUE,
user_id bigint not null REFERENCES db_test_user(id),
car_id bigint not null REFERENCES db_test_car(id)
);
`),
},
"migrations/04_iam.down.sql": {
name: "04_iam.down.sql",
bytes: []byte(`
BEGIN;
drop table if exists iam_scope CASCADE;
drop trigger if exists iam_scope_insert;
drop function if exists iam_sub_scopes_func;
COMMIT;
`),
},
"migrations/04_iam.up.sql": {
name: "04_iam.up.sql",
bytes: []byte(`
BEGIN;
CREATE TABLE if not exists iam_scope_type_enm (
string text NOT NULL primary key CHECK(string IN ('unknown', 'organization', 'project'))
);
INSERT INTO iam_scope_type_enm (string)
values
('unknown'),
('organization'),
('project');
CREATE TABLE if not exists iam_scope (
public_id wt_public_id primary key,
create_time timestamp with time zone default current_timestamp,
update_time timestamp with time zone default current_timestamp,
name text,
type text NOT NULL REFERENCES iam_scope_type_enm(string) CHECK(
(
type = 'organization'
and parent_id = NULL
)
or (
type = 'project'
and parent_id IS NOT NULL
)
),
description text,
parent_id text REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE
);
create table if not exists iam_scope_organization (
scope_id wt_public_id NOT NULL UNIQUE REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE,
name text UNIQUE,
primary key(scope_id)
);
create table if not exists iam_scope_project (
scope_id wt_public_id NOT NULL REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE,
parent_id wt_public_id NOT NULL REFERENCES iam_scope_organization(scope_id) ON DELETE CASCADE ON UPDATE CASCADE,
name text,
unique(parent_id, name),
primary key(scope_id, parent_id)
);
CREATE
OR REPLACE FUNCTION iam_sub_scopes_func() RETURNS TRIGGER
SET SCHEMA
'public' LANGUAGE plpgsql AS $$ DECLARE parent_type INT;
BEGIN IF new.type = 'organization' THEN
insert into iam_scope_organization (scope_id, name)
values
(new.public_id, new.name);
return NEW;
END IF;
IF new.type = 'project' THEN
insert into iam_scope_project (scope_id, parent_id, name)
values
(new.public_id, new.parent_id, new.name);
return NEW;
END IF;
RAISE EXCEPTION 'unknown scope type';
END;
$$;
CREATE TRIGGER iam_scope_insert
AFTER
insert ON iam_scope FOR EACH ROW EXECUTE PROCEDURE iam_sub_scopes_func();
CREATE
OR REPLACE FUNCTION iam_immutable_scope_type_func() RETURNS TRIGGER
SET SCHEMA
'public' LANGUAGE plpgsql AS $$ DECLARE parent_type INT;
BEGIN IF new.type != old.type THEN
RAISE EXCEPTION 'scope type cannot be updated';
END IF;
return NEW;
END;
$$;
CREATE TRIGGER iam_scope_update
BEFORE
update ON iam_scope FOR EACH ROW EXECUTE PROCEDURE iam_immutable_scope_type_func();
COMMIT;
`),
},
}

@ -15,7 +15,7 @@ import (
func TestGormReadWriter_Update(t *testing.T) {
// intentionally not run with t.Parallel so we don't need to use DoTx for the Update tests
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -163,7 +163,7 @@ func TestGormReadWriter_Update(t *testing.T) {
func TestGormReadWriter_Create(t *testing.T) {
// intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -271,7 +271,7 @@ func TestGormReadWriter_Create(t *testing.T) {
func TestGormReadWriter_LookupByName(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -326,7 +326,7 @@ func TestGormReadWriter_LookupByName(t *testing.T) {
func TestGormReadWriter_LookupByPublicId(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -381,7 +381,7 @@ func TestGormReadWriter_LookupByPublicId(t *testing.T) {
func TestGormReadWriter_LookupWhere(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -431,7 +431,7 @@ func TestGormReadWriter_LookupWhere(t *testing.T) {
func TestGormReadWriter_SearchWhere(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -481,7 +481,7 @@ func TestGormReadWriter_SearchWhere(t *testing.T) {
func TestGormReadWriter_DB(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -504,7 +504,7 @@ func TestGormReadWriter_DB(t *testing.T) {
func TestGormReadWriter_DoTx(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -604,7 +604,7 @@ func TestGormReadWriter_DoTx(t *testing.T) {
func TestGormReadWriter_Delete(t *testing.T) {
// intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()
@ -781,7 +781,7 @@ func TestGormReadWriter_Delete(t *testing.T) {
func TestGormReadWriter_ScanRows(t *testing.T) {
t.Parallel()
cleanup, db := TestSetup(t, "migrations/postgres")
cleanup, db := TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer db.Close()

@ -2,11 +2,8 @@ package db
import (
"crypto/rand"
"fmt"
"os"
"testing"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -15,19 +12,15 @@ import (
)
// setup the tests (initialize the database one time and initialize testDatabaseURL)
func TestSetup(t *testing.T, migrationsDirectory string) (func() error, *gorm.DB) {
if _, err := os.Stat(migrationsDirectory); os.IsNotExist(err) {
t.Fatal("error migrationsDirectory does not exist")
}
func TestSetup(t *testing.T, dialect string) (func() error, *gorm.DB) {
cleanup := func() error { return nil }
var url string
var err error
cleanup, url, err = initDbInDocker(t, migrationsDirectory)
cleanup, url, err = InitDbInDocker(dialect)
if err != nil {
t.Fatal(err)
}
db, err := gorm.Open("postgres", url)
db, err := gorm.Open(dialect, url)
if err != nil {
t.Fatal(err)
}
@ -51,30 +44,3 @@ func TestWrapper(t *testing.T) wrapping.Wrapper {
return root
}
// initDbInDocker initializes postgres within dockertest for the unit tests
func initDbInDocker(t *testing.T, migrationsDirectory string) (cleanup func() error, retURL string, err error) {
if os.Getenv("PG_URL") != "" {
TestInitStore(t, func() error { return nil }, os.Getenv("PG_URL"), migrationsDirectory)
return func() error { return nil }, os.Getenv("PG_URL"), nil
}
c, url, err := StartDbInDocker()
if err != nil {
return func() error { return nil }, "", fmt.Errorf("could not start docker: %w", err)
}
TestInitStore(t, c, url, migrationsDirectory)
return c, url, nil
}
// TestInitStore will execute the migrations needed to initialize the store for tests
func TestInitStore(t *testing.T, cleanup func() error, url string, migrationsDirectory string) {
// run migrations
m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDirectory), url)
if err != nil {
cleanup()
t.Fatalf("Error creating migrations: %s", err)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
cleanup()
t.Fatalf("Error running migrations: %s", err)
}
}
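
TestSetup now takes a dialect rather than a migrations directory, and all of the docker and migration plumbing above moved into InitDbInDocker. A hedged sketch of a test using the new signature (the test name is illustrative), in the same style as the call sites updated below:

package db_test

import (
    "testing"

    "github.com/hashicorp/watchtower/internal/db"
)

func TestWithMigratedStore(t *testing.T) {
    cleanup, conn := db.TestSetup(t, "postgres")
    defer cleanup()
    defer conn.Close()
    // conn is a *gorm.DB backed by a migrated postgres instance
    // (dockertest container, or PG_URL if that is set).
}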

@ -6,7 +6,7 @@ import (
func Test_Utils(t *testing.T) {
t.Parallel()
cleanup, conn := TestSetup(t, "migrations/postgres")
cleanup, conn := TestSetup(t, "postgres")
defer cleanup()
defer conn.Close()
t.Run("nothing", func(t *testing.T) {

@ -13,7 +13,7 @@ import (
func Test_Repository_CreateScope(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -52,7 +52,7 @@ func Test_Repository_CreateScope(t *testing.T) {
func Test_Repository_UpdateScope(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()

@ -15,7 +15,7 @@ import (
func TestNewRepository(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -100,7 +100,7 @@ func TestNewRepository(t *testing.T) {
}
func Test_Repository_create(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -149,7 +149,7 @@ func Test_Repository_create(t *testing.T) {
func Test_dbRepository_update(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()

@ -12,7 +12,7 @@ import (
func Test_NewScope(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -49,7 +49,7 @@ func Test_NewScope(t *testing.T) {
}
func Test_ScopeCreate(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -89,7 +89,7 @@ func Test_ScopeCreate(t *testing.T) {
func Test_ScopeUpdate(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -125,7 +125,7 @@ func Test_ScopeUpdate(t *testing.T) {
}
func Test_ScopeGetScope(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
@ -176,7 +176,7 @@ func TestScope_ResourceType(t *testing.T) {
func TestScope_Clone(t *testing.T) {
t.Parallel()
cleanup, conn := db.TestSetup(t, "../db/migrations/postgres")
cleanup, conn := db.TestSetup(t, "postgres")
defer cleanup()
assert := assert.New(t)
defer conn.Close()
