diff --git a/Makefile b/Makefile index 76a7c84be4..a2827dff01 100644 --- a/Makefile +++ b/Makefile @@ -5,16 +5,19 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST)) TMP_DIR := $(shell mktemp -d) REPO_PATH := github.com/hashicorp/watchtower -export APIGEN_BASEPATH := $(shell pwd) +export GEN_BASEPATH := $(shell pwd) bootstrap: go generate -tags tools tools/tools.go -gen: proto api +gen: proto api migrations api: $(MAKE) --environment-overrides -C api/internal/genapi api +migrations: + $(MAKE) --environment-overrides -C internal/db/migrations/genmigrations migrations + ### oplog requires protoc-gen-go v1.20.0 or later # GO111MODULE=on go get -u github.com/golang/protobuf/protoc-gen-go@v1.40 proto: protolint protobuild cleanup @@ -49,6 +52,6 @@ cleanup: @rm -R ${TMP_DIR} -.PHONY: api bootstrap cleanup gen proto +.PHONY: api bootstrap cleanup gen migrations proto .NOTPARALLEL: diff --git a/api/internal/genapi/Makefile b/api/internal/genapi/Makefile index 7124d1b5d6..7e68a83897 100644 --- a/api/internal/genapi/Makefile +++ b/api/internal/genapi/Makefile @@ -4,6 +4,6 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST)) api: go run -tags genapi . - goimports -w ${APIGEN_BASEPATH}/api + goimports -w ${GEN_BASEPATH}/api .PHONY: api diff --git a/api/internal/genapi/create.go b/api/internal/genapi/create.go index 151515bb26..0612eb097d 100644 --- a/api/internal/genapi/create.go +++ b/api/internal/genapi/create.go @@ -37,7 +37,7 @@ var createFuncs = map[string][]*createInfo{ func writeCreateFuncs() { for outPkg, funcs := range createFuncs { - outFile := os.Getenv("APIGEN_BASEPATH") + fmt.Sprintf("/api/%s/create.gen.go", outPkg) + outFile := os.Getenv("GEN_BASEPATH") + fmt.Sprintf("/api/%s/create.gen.go", outPkg) outBuf := bytes.NewBuffer([]byte(fmt.Sprintf( `// Code generated by "make api"; DO NOT EDIT. package %s @@ -92,4 +92,4 @@ func (s {{ .BaseType }}) Create{{ .TargetType }}(ctx context.Context, {{ .LowerT return target, apiErr, nil } - `)) +`)) diff --git a/api/internal/genapi/input.go b/api/internal/genapi/input.go index c45d0baab2..4878f33ebd 100644 --- a/api/internal/genapi/input.go +++ b/api/internal/genapi/input.go @@ -18,9 +18,9 @@ type structInfo struct { var inputStructs = []*structInfo{ { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go", "Error", - os.Getenv("APIGEN_BASEPATH") + "/api/error.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/error.gen.go", "Error", "api", "", @@ -29,9 +29,9 @@ var inputStructs = []*structInfo{ templateTypeResource, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/error.pb.go", "ErrorDetails", - os.Getenv("APIGEN_BASEPATH") + "/api/error_details.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/error_details.gen.go", "ErrorDetails", "api", "", @@ -40,9 +40,9 @@ var inputStructs = []*structInfo{ templateTypeResource, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host.pb.go", "Host", - os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/hosts/host.gen.go", "Host", "hosts", "", @@ -51,9 +51,9 @@ var inputStructs = []*structInfo{ templateTypeResource, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_set.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_set.pb.go", "HostSet", - os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host_set.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/hosts/host_set.gen.go", "HostSet", "hosts", "", @@ -62,9 +62,9 @@ var inputStructs = []*structInfo{ templateTypeResource, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", "HostCatalog", - os.Getenv("APIGEN_BASEPATH") + "/api/hosts/host_catalog.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/hosts/host_catalog.gen.go", "HostCatalog", "hosts", "", @@ -73,9 +73,9 @@ var inputStructs = []*structInfo{ templateTypeResource, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", "StaticHostCatalogDetails", - os.Getenv("APIGEN_BASEPATH") + "/api/hosts/static_host_catalog.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/hosts/static_host_catalog.gen.go", "StaticHostCatalogDetails", "hosts", "", @@ -84,9 +84,9 @@ var inputStructs = []*structInfo{ templateTypeDetail, }, { - os.Getenv("APIGEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", + os.Getenv("GEN_BASEPATH") + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go", "AwsEc2HostCatalogDetails", - os.Getenv("APIGEN_BASEPATH") + "/api/hosts/awsec2_host_catalog.gen.go", + os.Getenv("GEN_BASEPATH") + "/api/hosts/awsec2_host_catalog.gen.go", "AwsEc2HostCatalogDetails", "hosts", "", diff --git a/go.mod b/go.mod index 866f8c4e3b..62046dda63 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/hashicorp/go-uuid v1.0.2 github.com/hashicorp/hcl v1.0.0 github.com/hashicorp/vault v1.2.1-0.20200310200732-a321b1217d5a - github.com/hashicorp/vault/api v1.0.5-0.20200310200732-a321b1217d5a + github.com/hashicorp/vault/api v1.0.5-0.20200310200732-a321b1217d5a // indirect github.com/hashicorp/vault/sdk v0.1.14-0.20200310200732-a321b1217d5a github.com/jackc/pgx/v4 v4.6.0 github.com/jinzhu/gorm v1.9.12 @@ -39,6 +39,7 @@ require ( github.com/mitchellh/cli v1.0.0 github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/mapstructure v1.2.2 + github.com/ory/dockertest v3.3.5+incompatible github.com/ory/dockertest/v3 v3.5.4 github.com/pelletier/go-toml v1.7.0 // indirect github.com/pires/go-proxyproto v0.0.0-20200213100827-833e5d06d8f0 diff --git a/go.sum b/go.sum index eda7970a1e..639b415776 100644 --- a/go.sum +++ b/go.sum @@ -185,6 +185,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381 h1:rdRS5BT13Iae9ssvcslol66gfOOXjaLYwqerEn/cl9s= github.com/cloudfoundry-community/go-cfclient v0.0.0-20190201205600-f136f9222381/go.mod h1:e5+USP2j8Le2M0Jo3qKPFnNhuo1wueU4nWHCXBOfQ14= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c h1:2zRrJWIt/f9c9HhNHAgrRgq0San5gRRUJTBXLkchal0= github.com/cockroachdb/cockroach-go v0.0.0-20181001143604-e0a95dfd547c/go.mod h1:XGLbWH/ujMcbPbhZq52Nv6UrCghb1yGn//133kEsvDk= @@ -722,6 +723,7 @@ github.com/jackc/pgconn v1.5.0 h1:oFSOilzIZkyg787M1fEmyMfOUUvwj0daqYMfaWwNL4o= github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= @@ -1063,6 +1065,7 @@ github.com/shirou/gopsutil v2.19.9+incompatible h1:IrPVlK4nfwW10DF7pW+7YJKws9Nkg github.com/shirou/gopsutil v2.19.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 900db697b1..a2bf441719 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "crypto/tls" - "database/sql" "errors" "fmt" "io" @@ -18,7 +17,6 @@ import ( "github.com/hashicorp/errwrap" "github.com/hashicorp/go-hclog" wrapping "github.com/hashicorp/go-kms-wrapping" - "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/gatedwriter" "github.com/hashicorp/vault/internalshared/reloadutil" @@ -26,9 +24,9 @@ import ( "github.com/hashicorp/vault/sdk/helper/mlock" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/watchtower/globals" + "github.com/hashicorp/watchtower/internal/db" "github.com/hashicorp/watchtower/version" "github.com/mitchellh/cli" - "github.com/ory/dockertest/v3" "golang.org/x/net/http/httpproxy" "google.golang.org/grpc/grpclog" @@ -60,12 +58,8 @@ type Server struct { Listeners []*ServerListener - DevDatabasePassword string - DevDatabaseName string - DevDatabasePort string - - dockertestPool *dockertest.Pool - dockertestResource *dockertest.Resource + DevDatabaseUrl string + DevDatabaseCleanupFunc func() error } func NewServer() *Server { @@ -372,54 +366,23 @@ func (b *Server) RunShutdownFuncs(ui cli.Ui) { } func (b *Server) CreateDevDatabase() error { - pool, err := dockertest.NewPool("") - if err != nil { - return fmt.Errorf("could not connect to docker: %w", err) - } - - resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"}) + c, url, err := db.InitDbInDocker("postgres") if err != nil { - return fmt.Errorf("could not start resource: %w", err) - } - - if err := pool.Retry(func() error { - db, err := sql.Open("postgres", fmt.Sprintf("postgres://postgres:secret@localhost:%s/watchtower?sslmode=disable", resource.GetPort("5432/tcp"))) - if err != nil { - return fmt.Errorf("error opening postgres dev container: %w", err) - } - var mErr *multierror.Error - if err := db.Ping(); err != nil { - mErr = multierror.Append(fmt.Errorf("error pinging dev database container: %w", err)) - } - if err := db.Close(); err != nil { - mErr = multierror.Append(fmt.Errorf("error closing dev database container: %w", err)) - } - return mErr.ErrorOrNil() - }); err != nil { - return fmt.Errorf("could not connect to docker: %w", err) + c() + return fmt.Errorf("unable to start dev database: %w", err) } - b.dockertestPool = pool - b.dockertestResource = resource - b.DevDatabaseName = "watchtower" - b.DevDatabasePassword = "secret" - b.DevDatabasePort = resource.GetPort("5432/tcp") + b.DevDatabaseCleanupFunc = c + b.DevDatabaseUrl = url - b.InfoKeys = append(b.InfoKeys, "dev database name", "dev database password", "dev database port") - b.Info["dev database name"] = "watchtower" - b.Info["dev database password"] = "secret" - b.Info["dev database port"] = b.DevDatabasePort + b.InfoKeys = append(b.InfoKeys, "dev database url") + b.Info["dev database url"] = b.DevDatabaseUrl return nil } func (b *Server) DestroyDevDatabase() error { - if b.dockertestPool == nil { - return nil + if b.DevDatabaseCleanupFunc != nil { + return b.DevDatabaseCleanupFunc() } - - if b.dockertestResource == nil { - return errors.New("found a pool for dev database container but no resource") - } - - return b.dockertestPool.Purge(b.dockertestResource) + return errors.New("no dev database cleanup function found") } diff --git a/internal/db/db.go b/internal/db/db.go index 972baf10bc..9739c222dc 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/golang-migrate/migrate/v4" + "github.com/hashicorp/watchtower/internal/db/migrations" "github.com/jinzhu/gorm" "github.com/ory/dockertest" ) @@ -55,28 +56,57 @@ func Migrate(connectionUrl string, migrationsDirectory string) error { return nil } +// InitDbInDocker initializes the data store within docker or an existing PG_URL +func InitDbInDocker(dialect string) (cleanup func() error, retURL string, err error) { + switch dialect { + case "postgres": + if os.Getenv("PG_URL") != "" { + if err := InitStore(dialect, func() error { return nil }, os.Getenv("PG_URL")); err != nil { + return func() error { return nil }, os.Getenv("PG_URL"), fmt.Errorf("error initializing store: %w", err) + } + return func() error { return nil }, os.Getenv("PG_URL"), nil + } + } + c, url, err := StartDbInDocker(dialect) + if err != nil { + return func() error { return nil }, "", fmt.Errorf("could not start docker: %w", err) + } + if err := InitStore(dialect, c, url); err != nil { + return func() error { return nil }, "", fmt.Errorf("error initializing store: %w", err) + } + return c, url, nil +} + // StartDbInDocker -func StartDbInDocker() (cleanup func() error, retURL string, err error) { +func StartDbInDocker(dialect string) (cleanup func() error, retURL string, err error) { pool, err := dockertest.NewPool("") if err != nil { return func() error { return nil }, "", fmt.Errorf("could not connect to docker: %w", err) } - resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"}) + var resource *dockertest.Resource + var url string + switch dialect { + case "postgres": + resource, err = pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"}) + url = "postgres://postgres:secret@localhost:%s?sslmode=disable" + default: + panic(fmt.Sprintf("unknown dialect %q", dialect)) + } if err != nil { return func() error { return nil }, "", fmt.Errorf("could not start resource: %w", err) } - c := func() error { + cleanup = func() error { return cleanupDockerResource(pool, resource) } - url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp")) + url = fmt.Sprintf(url, resource.GetPort("5432/tcp")) if err := pool.Retry(func() error { - db, err := sql.Open("postgres", url) + db, err := sql.Open(dialect, url) if err != nil { - return fmt.Errorf("error opening postgres dev container: %w", err) + return fmt.Errorf("error opening %s dev container: %w", dialect, err) } if err := db.Ping(); err != nil { @@ -87,7 +117,29 @@ func StartDbInDocker() (cleanup func() error, retURL string, err error) { }); err != nil { return func() error { return nil }, "", fmt.Errorf("could not connect to docker: %w", err) } - return c, url, nil + + return cleanup, url, nil +} + +// InitStore will execute the migrations needed to initialize the store for tests +func InitStore(dialect string, cleanup func() error, url string) error { + // run migrations + source, err := migrations.NewMigrationSource(dialect) + if err != nil { + cleanup() + return fmt.Errorf("error creating migration driver: %w", err) + } + m, err := migrate.NewWithSourceInstance("httpfs", source, url) + if err != nil { + cleanup() + return fmt.Errorf("error creating migrations: %w", err) + } + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + cleanup() + return fmt.Errorf("error running migrations: %w", err) + } + + return nil } // cleanupDockerResource will clean up the dockertest resources (postgres) diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 273577d5d6..6e8819964c 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -5,7 +5,7 @@ import ( ) func TestOpen(t *testing.T) { - cleanup, url, err := StartDbInDocker() + cleanup, url, err := StartDbInDocker("postgres") if err != nil { t.Fatal(err) } @@ -56,7 +56,7 @@ func TestOpen(t *testing.T) { } func TestMigrate(t *testing.T) { - cleanup, url, err := StartDbInDocker() + cleanup, url, err := StartDbInDocker("postgres") if err != nil { t.Fatal(err) } diff --git a/internal/db/domains_test.go b/internal/db/domains_test.go index 8e9605b399..014de9e6e8 100644 --- a/internal/db/domains_test.go +++ b/internal/db/domains_test.go @@ -21,7 +21,7 @@ returning id; ` ) - cleanup, conn := TestSetup(t, "migrations/postgres") + cleanup, conn := TestSetup(t, "postgres") defer cleanup() defer conn.Close() diff --git a/internal/db/migrations/driver.go b/internal/db/migrations/driver.go new file mode 100644 index 0000000000..0fbe82298f --- /dev/null +++ b/internal/db/migrations/driver.go @@ -0,0 +1,121 @@ +package migrations + +import ( + "bytes" + "fmt" + "net/http" + "os" + "sort" + "strings" + "time" + + "github.com/golang-migrate/migrate/v4/source" + "github.com/golang-migrate/migrate/v4/source/httpfs" +) + +// migrationDriver satisfies the remaining need of the Driver interface, since +// the package uses PartialDriver under the hood +type migrationDriver struct { + dialect string +} + +// Open returns the given "file" +func (m *migrationDriver) Open(name string) (http.File, error) { + var ff *fakeFile + switch m.dialect { + case "postgres": + ff = postgresMigrations[name] + } + if ff == nil { + return nil, os.ErrNotExist + } + ff.name = strings.TrimPrefix(name, "migrations/") + ff.reader = bytes.NewReader(ff.bytes) + ff.dialect = m.dialect + return ff, nil +} + +// NewMigrationSource creates a source.Driver using httpfs with the given dialect +func NewMigrationSource(dialect string) (source.Driver, error) { + switch dialect { + case "postgres": + default: + return nil, fmt.Errorf("unknown migrations dialect %s", dialect) + } + return httpfs.New(&migrationDriver{dialect}, "migrations") +} + +// fakeFile is used to satisfy the http.File interface +type fakeFile struct { + name string + bytes []byte + reader *bytes.Reader + dialect string +} + +func (f *fakeFile) Read(p []byte) (n int, err error) { + return f.reader.Read(p) +} + +func (f *fakeFile) Seek(offset int64, whence int) (int64, error) { + return f.reader.Seek(offset, whence) +} + +func (f *fakeFile) Close() error { return nil } + +// Readdir returns os.FileInfo values, in sorted order, and eliding the +// migrations "dir" +func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { + // Get the right map + var migrationsMap map[string]*fakeFile + switch f.dialect { + case "postgres": + migrationsMap = postgresMigrations + default: + return nil, fmt.Errorf("unknown database dialect %s", f.dialect) + } + + // Sort the keys. May not be necessary but feels nice. + keys := make([]string, 0, len(migrationsMap)) + for k := range migrationsMap { + keys = append(keys, k) + } + sort.Strings(keys) + + // Create the slice of fileinfo objects to return + ret := make([]os.FileInfo, 0, len(migrationsMap)) + for _, v := range keys { + // We need "migrations" in the map for the initial Open call but we + // should not return it as part of the "directory"'s "files". + if v == "migrations" { + continue + } + stat, err := migrationsMap[v].Stat() + if err != nil { + return nil, err + } + ret = append(ret, stat) + } + return ret, nil +} + +// Stat returns a new fakeFileInfo object with the necessary bits +func (f *fakeFile) Stat() (os.FileInfo, error) { + return &fakeFileInfo{ + name: f.name, + size: int64(len(f.bytes)), + }, nil +} + +// fakeFileInfo satisfies os.FileInfo but represents our fake "files" +type fakeFileInfo struct { + name string + size int64 +} + +func (f *fakeFileInfo) Name() string { return f.name } +func (f *fakeFileInfo) Size() int64 { return f.size } +func (f *fakeFileInfo) Mode() os.FileMode { return os.ModePerm } +func (f *fakeFileInfo) ModTime() time.Time { return time.Now() } +func (f *fakeFileInfo) IsDir() bool { return false } +func (f *fakeFileInfo) Sys() interface{} { return nil } diff --git a/internal/db/migrations/genmigrations/Makefile b/internal/db/migrations/genmigrations/Makefile new file mode 100644 index 0000000000..a68b29424a --- /dev/null +++ b/internal/db/migrations/genmigrations/Makefile @@ -0,0 +1,8 @@ +# Be sure to place this BEFORE `include` directives, if any. +THIS_FILE := $(lastword $(MAKEFILE_LIST)) + +migrations: + go run -tags genmigrations . + goimports -w ${GEN_BASEPATH}/internal/db/migrations + +.PHONY: migrations diff --git a/internal/db/migrations/genmigrations/generate.go b/internal/db/migrations/genmigrations/generate.go new file mode 100644 index 0000000000..8e6e7d8041 --- /dev/null +++ b/internal/db/migrations/genmigrations/generate.go @@ -0,0 +1,88 @@ +// +build genmigrations + +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "sort" + "strings" + "text/template" +) + +// generate looks for migration sql in a directory for the given dialect and +// applies the templates below to the contents of the files, building up a +// migrations map for the dialect +func generate(dialect string) { + baseDir := os.Getenv("GEN_BASEPATH") + fmt.Sprint("/internal/db/migrations") + dir, err := os.Open(fmt.Sprintf("%s/%s", baseDir, dialect)) + if err != nil { + fmt.Printf("error opening dir with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + names, err := dir.Readdirnames(0) + if err != nil { + fmt.Printf("error reading dir names with dialect %s: %v\n", dialect, err) + os.Exit(1) + } + outBuf := bytes.NewBuffer(nil) + valuesBuf := bytes.NewBuffer(nil) + + sort.Strings(names) + + for _, name := range names { + if !strings.HasSuffix(name, ".sql") { + continue + } + contents, err := ioutil.ReadFile(fmt.Sprintf("%s/%s/%s", baseDir, dialect, name)) + if err != nil { + fmt.Printf("error opening file %s with dialect %s: %v", name, dialect, err) + os.Exit(1) + } + migrationsValueTemplate.Execute(valuesBuf, struct { + Name string + Contents string + }{ + Name: name, + Contents: string(contents), + }) + } + migrationsTemplate.Execute(outBuf, struct { + Type string + Values string + }{ + Type: dialect, + Values: valuesBuf.String(), + }) + + outFile := fmt.Sprintf("%s/%s.gen.go", baseDir, dialect) + if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0644); err != nil { + fmt.Printf("error writing file %q: %v\n", outFile, err) + os.Exit(1) + } +} + +var migrationsTemplate = template.Must(template.New("").Parse( + `// Code generated by "make migrations"; DO NOT EDIT. +package migrations + +import ( + "bytes" +) + +var {{ .Type }}Migrations = map[string]*fakeFile{ + "migrations": { + name: "migrations", + }, + {{ .Values }} +} +`)) + +var migrationsValueTemplate = template.Must(template.New("").Parse( + `"migrations/{{ .Name }}": { + name: "{{ .Name }}", + bytes: []byte(` + "`\n{{ .Contents }}\n`" + `), + }, +`)) diff --git a/internal/db/migrations/genmigrations/main.go b/internal/db/migrations/genmigrations/main.go new file mode 100644 index 0000000000..0669cf2fbb --- /dev/null +++ b/internal/db/migrations/genmigrations/main.go @@ -0,0 +1,7 @@ +// +build genmigrations + +package main + +func main() { + generate("postgres") +} diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go new file mode 100644 index 0000000000..debdcbea1e --- /dev/null +++ b/internal/db/migrations/postgres.gen.go @@ -0,0 +1,225 @@ +// Code generated by "make migrations"; DO NOT EDIT. +package migrations + +var postgresMigrations = map[string]*fakeFile{ + "migrations": { + name: "migrations", + }, + "migrations/01_oplog.up.sql": { + name: "01_oplog.up.sql", + bytes: []byte(` +CREATE TABLE if not exists oplog_entry ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + version text NOT NULL, + aggregate_name text NOT NULL, + "data" bytea NOT NULL +); +CREATE TABLE if not exists oplog_ticket ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + "name" text NOT NULL UNIQUE, + "version" bigint NOT NULL +); +CREATE TABLE if not exists oplog_metadata ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + entry_id bigint NOT NULL REFERENCES oplog_entry(id) ON DELETE CASCADE ON UPDATE CASCADE, + "key" text NOT NULL, + value text NULL +); +create index if not exists idx_oplog_metatadata_key on oplog_metadata(key); +create index if not exists idx_oplog_metatadata_value on oplog_metadata(value); +INSERT INTO oplog_ticket (name, version) +values + ('default', 1), + ('iam_scope', 1), + ('iam_user', 1), + ('iam_auth_method', 1), + ('iam_group', 1), + ('iam_group_member_user', 1), + ('iam_role', 1), + ('iam_role_grant', 1), + ('iam_role_group', 1), + ('iam_role_user', 1), + ('db_test_user', 1), + ('db_test_car', 1), + ('db_test_rental', 1); +`), + }, + "migrations/02_domain_types.down.sql": { + name: "02_domain_types.down.sql", + bytes: []byte(` +begin; + +drop domain wt_public_id; + +commit; + +`), + }, + "migrations/02_domain_types.up.sql": { + name: "02_domain_types.up.sql", + bytes: []byte(` +begin; + +create domain wt_public_id as text +check( + length(trim(value)) > 10 +); +comment on domain wt_public_id is +'Random ID generated with github.com/hashicorp/vault/sdk/helper/base62'; + +commit; + +`), + }, + "migrations/03_db.down.sql": { + name: "03_db.down.sql", + bytes: []byte(` +drop table if exists db_test_user; +drop table if exists db_test_car; +drop table if exists db_test_rental; +`), + }, + "migrations/03_db.up.sql": { + name: "03_db.up.sql", + bytes: []byte(` +-- create test tables used in the unit tests for the internal/db package +-- these tables (db_test_user, db_test_car, db_test_rental) are not part +-- of the Watchtower domain model... they are simply used for testing the internal/db package +CREATE TABLE if not exists db_test_user ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + public_id text NOT NULL UNIQUE, + name text UNIQUE, + phone_number text, + email text +); +CREATE TABLE if not exists db_test_car ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + public_id text NOT NULL UNIQUE, + name text UNIQUE, + model text, + mpg smallint +); +CREATE TABLE if not exists db_test_rental ( + id bigint generated always as identity primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + public_id text NOT NULL UNIQUE, + name text UNIQUE, + user_id bigint not null REFERENCES db_test_user(id), + car_id bigint not null REFERENCES db_test_car(id) +); +`), + }, + "migrations/04_iam.down.sql": { + name: "04_iam.down.sql", + bytes: []byte(` +BEGIN; + +drop table if exists iam_scope CASCADE; +drop trigger if exists iam_scope_insert; +drop function if exists iam_sub_scopes_func; + +COMMIT; +`), + }, + "migrations/04_iam.up.sql": { + name: "04_iam.up.sql", + bytes: []byte(` +BEGIN; + +CREATE TABLE if not exists iam_scope_type_enm ( + string text NOT NULL primary key CHECK(string IN ('unknown', 'organization', 'project')) +); +INSERT INTO iam_scope_type_enm (string) +values + ('unknown'), + ('organization'), + ('project'); + + +CREATE TABLE if not exists iam_scope ( + public_id wt_public_id primary key, + create_time timestamp with time zone default current_timestamp, + update_time timestamp with time zone default current_timestamp, + name text, + type text NOT NULL REFERENCES iam_scope_type_enm(string) CHECK( + ( + type = 'organization' + and parent_id = NULL + ) + or ( + type = 'project' + and parent_id IS NOT NULL + ) + ), + description text, + parent_id text REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE + ); +create table if not exists iam_scope_organization ( + scope_id wt_public_id NOT NULL UNIQUE REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE, + name text UNIQUE, + primary key(scope_id) + ); +create table if not exists iam_scope_project ( + scope_id wt_public_id NOT NULL REFERENCES iam_scope(public_id) ON DELETE CASCADE ON UPDATE CASCADE, + parent_id wt_public_id NOT NULL REFERENCES iam_scope_organization(scope_id) ON DELETE CASCADE ON UPDATE CASCADE, + name text, + unique(parent_id, name), + primary key(scope_id, parent_id) + ); + + +CREATE + OR REPLACE FUNCTION iam_sub_scopes_func() RETURNS TRIGGER +SET SCHEMA + 'public' LANGUAGE plpgsql AS $$ DECLARE parent_type INT; +BEGIN IF new.type = 'organization' THEN +insert into iam_scope_organization (scope_id, name) +values + (new.public_id, new.name); +return NEW; +END IF; +IF new.type = 'project' THEN +insert into iam_scope_project (scope_id, parent_id, name) +values + (new.public_id, new.parent_id, new.name); +return NEW; +END IF; +RAISE EXCEPTION 'unknown scope type'; +END; +$$; + + +CREATE TRIGGER iam_scope_insert +AFTER +insert ON iam_scope FOR EACH ROW EXECUTE PROCEDURE iam_sub_scopes_func(); + + +CREATE + OR REPLACE FUNCTION iam_immutable_scope_type_func() RETURNS TRIGGER +SET SCHEMA + 'public' LANGUAGE plpgsql AS $$ DECLARE parent_type INT; +BEGIN IF new.type != old.type THEN +RAISE EXCEPTION 'scope type cannot be updated'; +END IF; +return NEW; +END; +$$; + +CREATE TRIGGER iam_scope_update +BEFORE +update ON iam_scope FOR EACH ROW EXECUTE PROCEDURE iam_immutable_scope_type_func(); + +COMMIT; +`), + }, +} diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 90ae589b8b..04ba3e1ec9 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -15,7 +15,7 @@ import ( func TestGormReadWriter_Update(t *testing.T) { // intentionally not run with t.Parallel so we don't need to use DoTx for the Update tests - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -163,7 +163,7 @@ func TestGormReadWriter_Update(t *testing.T) { func TestGormReadWriter_Create(t *testing.T) { // intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -271,7 +271,7 @@ func TestGormReadWriter_Create(t *testing.T) { func TestGormReadWriter_LookupByName(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -326,7 +326,7 @@ func TestGormReadWriter_LookupByName(t *testing.T) { func TestGormReadWriter_LookupByPublicId(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -381,7 +381,7 @@ func TestGormReadWriter_LookupByPublicId(t *testing.T) { func TestGormReadWriter_LookupWhere(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -431,7 +431,7 @@ func TestGormReadWriter_LookupWhere(t *testing.T) { func TestGormReadWriter_SearchWhere(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -481,7 +481,7 @@ func TestGormReadWriter_SearchWhere(t *testing.T) { func TestGormReadWriter_DB(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -504,7 +504,7 @@ func TestGormReadWriter_DB(t *testing.T) { func TestGormReadWriter_DoTx(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -604,7 +604,7 @@ func TestGormReadWriter_DoTx(t *testing.T) { func TestGormReadWriter_Delete(t *testing.T) { // intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() @@ -781,7 +781,7 @@ func TestGormReadWriter_Delete(t *testing.T) { func TestGormReadWriter_ScanRows(t *testing.T) { t.Parallel() - cleanup, db := TestSetup(t, "migrations/postgres") + cleanup, db := TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer db.Close() diff --git a/internal/db/testing.go b/internal/db/testing.go index 02af8666a3..6b7af4c216 100644 --- a/internal/db/testing.go +++ b/internal/db/testing.go @@ -2,11 +2,8 @@ package db import ( "crypto/rand" - "fmt" - "os" "testing" - "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" wrapping "github.com/hashicorp/go-kms-wrapping" @@ -15,19 +12,15 @@ import ( ) // setup the tests (initialize the database one-time and intialized testDatabaseURL) -func TestSetup(t *testing.T, migrationsDirectory string) (func() error, *gorm.DB) { - if _, err := os.Stat(migrationsDirectory); os.IsNotExist(err) { - t.Fatal("error migrationsDirectory does not exist") - } - +func TestSetup(t *testing.T, dialect string) (func() error, *gorm.DB) { cleanup := func() error { return nil } var url string var err error - cleanup, url, err = initDbInDocker(t, migrationsDirectory) + cleanup, url, err = InitDbInDocker(dialect) if err != nil { t.Fatal(err) } - db, err := gorm.Open("postgres", url) + db, err := gorm.Open(dialect, url) if err != nil { t.Fatal(err) } @@ -51,30 +44,3 @@ func TestWrapper(t *testing.T) wrapping.Wrapper { return root } -// initDbInDocker initializes postgres within dockertest for the unit tests -func initDbInDocker(t *testing.T, migrationsDirectory string) (cleanup func() error, retURL string, err error) { - if os.Getenv("PG_URL") != "" { - TestInitStore(t, func() error { return nil }, os.Getenv("PG_URL"), migrationsDirectory) - return func() error { return nil }, os.Getenv("PG_URL"), nil - } - c, url, err := StartDbInDocker() - if err != nil { - return func() error { return nil }, "", fmt.Errorf("could not start docker: %w", err) - } - TestInitStore(t, c, url, migrationsDirectory) - return c, url, nil -} - -// TestInitStore will execute the migrations needed to initialize the store for tests -func TestInitStore(t *testing.T, cleanup func() error, url string, migrationsDirectory string) { - // run migrations - m, err := migrate.New(fmt.Sprintf("file://%s", migrationsDirectory), url) - if err != nil { - cleanup() - t.Fatalf("Error creating migrations: %s", err) - } - if err := m.Up(); err != nil && err != migrate.ErrNoChange { - cleanup() - t.Fatalf("Error running migrations: %s", err) - } -} diff --git a/internal/db/testing_test.go b/internal/db/testing_test.go index db0caf29b6..03b1833711 100644 --- a/internal/db/testing_test.go +++ b/internal/db/testing_test.go @@ -6,7 +6,7 @@ import ( func Test_Utils(t *testing.T) { t.Parallel() - cleanup, conn := TestSetup(t, "migrations/postgres") + cleanup, conn := TestSetup(t, "postgres") defer cleanup() defer conn.Close() t.Run("nothing", func(t *testing.T) { diff --git a/internal/iam/repository_scope_test.go b/internal/iam/repository_scope_test.go index 31c13fcc09..d562719f28 100644 --- a/internal/iam/repository_scope_test.go +++ b/internal/iam/repository_scope_test.go @@ -13,7 +13,7 @@ import ( func Test_Repository_CreateScope(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -52,7 +52,7 @@ func Test_Repository_CreateScope(t *testing.T) { func Test_Repository_UpdateScope(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() diff --git a/internal/iam/repository_test.go b/internal/iam/repository_test.go index 100189bfc6..d9841bbfa9 100644 --- a/internal/iam/repository_test.go +++ b/internal/iam/repository_test.go @@ -15,7 +15,7 @@ import ( func TestNewRepository(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -100,7 +100,7 @@ func TestNewRepository(t *testing.T) { } func Test_Repository_create(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -149,7 +149,7 @@ func Test_Repository_create(t *testing.T) { func Test_dbRepository_update(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() diff --git a/internal/iam/scope_test.go b/internal/iam/scope_test.go index 42e3f56167..2bd2f405e3 100644 --- a/internal/iam/scope_test.go +++ b/internal/iam/scope_test.go @@ -12,7 +12,7 @@ import ( func Test_NewScope(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -49,7 +49,7 @@ func Test_NewScope(t *testing.T) { } func Test_ScopeCreate(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -89,7 +89,7 @@ func Test_ScopeCreate(t *testing.T) { func Test_ScopeUpdate(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -125,7 +125,7 @@ func Test_ScopeUpdate(t *testing.T) { } func Test_ScopeGetScope(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close() @@ -176,7 +176,7 @@ func TestScope_ResourceType(t *testing.T) { func TestScope_Clone(t *testing.T) { t.Parallel() - cleanup, conn := db.TestSetup(t, "../db/migrations/postgres") + cleanup, conn := db.TestSetup(t, "postgres") defer cleanup() assert := assert.New(t) defer conn.Close()