diff --git a/go.mod b/go.mod index 05ffaa5a7c..b514b66250 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/hashicorp/go-sockaddr v1.0.2 github.com/hashicorp/go-uuid v1.0.2 github.com/hashicorp/hcl v1.0.0 - github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641 + github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19 github.com/hashicorp/vault/sdk v0.1.14-0.20200519221912-a8c2591d3641 github.com/jackc/pgx/v4 v4.6.0 github.com/jinzhu/gorm v1.9.12 diff --git a/go.sum b/go.sum index 9463fc37de..ca9e2afffa 100644 --- a/go.sum +++ b/go.sum @@ -562,13 +562,12 @@ github.com/grpc-ecosystem/grpc-gateway v1.14.5 h1:aiLxiiVzAXb7wb3lAmubA69IokWOoU github.com/grpc-ecosystem/grpc-gateway v1.14.5/go.mod h1:UJ0EZAp832vCd54Wev9N1BMKEyvcZ5+IM0AwDrnlkEc= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= -github.com/hashicorp/consul-template v0.22.0 h1:ti5cqAekOeMfFYLJCjlPtKGwBcqwVxoZO/Y2vctwuUE= -github.com/hashicorp/consul-template v0.22.0/go.mod h1:lHrykBIcPobCuEcIMLJryKxDyk2lUMnQWmffOEONH0k= -github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= -github.com/hashicorp/consul/api v1.2.1-0.20200128105449-6681be918a6e h1:vOqdnsq53winzJDN6RTQe9n9g87S595PNsdwKyBWXRM= -github.com/hashicorp/consul/api v1.2.1-0.20200128105449-6681be918a6e/go.mod h1:ztzLK20HA5O27oTf2j/wbNgq8qj/crN8xsSx7pzX0sc= +github.com/hashicorp/consul-template v0.25.0 h1:wsnv4jSqBIVzlg6U0wNg+ePzfrsF3Vi9MqIqDEUrg9U= +github.com/hashicorp/consul-template v0.25.0/go.mod h1:/vUsrJvDuuQHcxEw0zik+YXTS7ZKWZjQeaQhshBmfH0= +github.com/hashicorp/consul/api v1.4.0 h1:jfESivXnO5uLdH650JU/6AnjRoHrLhULq0FnC3Kp9EY= +github.com/hashicorp/consul/api v1.4.0/go.mod h1:xc8u05kyMa3Wjr9eEAsIAo3dg8+LywT5E/Cl7cNS5nU= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= -github.com/hashicorp/consul/sdk v0.2.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/consul/sdk v0.4.0/go.mod h1:fY08Y9z5SvJqevyZNy6WWPXiG3KwBPAvlcdx16zZ0fM= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-alpnmux v0.0.0-20200513011953-0293f5d23c31 h1:pxqI71/0R1WIASjQEJ9W9skCKYiREEkRoXFvHCZH1pg= @@ -657,8 +656,8 @@ github.com/hashicorp/raft-snapshot v1.0.2-0.20190827162939-8117efcc5aab/go.mod h github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hashicorp/serf v0.8.3 h1:MWYcmct5EtKz0efYooPcL0yNkem+7kWxqXDi/UIh+8k= github.com/hashicorp/serf v0.8.3/go.mod h1:UpNcs7fFbpKIyZaUuSW6EPiH+eZC7OuyFD+wc1oal+k= -github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641 h1:sFzGh4bThye9dcV3XhO6xOtWlhXY73cgfod/B3eqdf0= -github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641/go.mod h1:AxZRQLEuX4UQt4GIITGJE/s/dpQJjwQnPfufvchqxFM= +github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19 h1:UnyS8rJUziC7TLs/0blpiCpbnq3l3fwzkpZG+VLLUWI= +github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19/go.mod h1:m42VsHQcRRfOQWFUF7eDOjYE18W+ZH2KgQzJezt7YKo= github.com/hashicorp/vault-plugin-auth-alicloud v0.5.5 h1:JYf3VYpKs7mOdtcwZWi73S82oXrC/JR7uoPVUd8c4Hk= github.com/hashicorp/vault-plugin-auth-alicloud v0.5.5/go.mod h1:sQ+VNwPQlemgXHXikYH6onfH9gPwDZ1GUVRLz0ZvHx8= github.com/hashicorp/vault-plugin-auth-azure v0.5.5 h1:kN79ai+aMVU9hUmwscHjmweW2fGa8V/t+ScIchPZGrk= @@ -682,8 +681,7 @@ github.com/hashicorp/vault-plugin-database-elasticsearch v0.5.4 h1:YE4qndazWmYGp github.com/hashicorp/vault-plugin-database-elasticsearch v0.5.4/go.mod h1:QjGrrxcRXv/4XkEZAlM0VMZEa3uxKAICFqDj27FP/48= github.com/hashicorp/vault-plugin-database-mongodbatlas v0.1.1 h1:fA6cFH8lIPH2M4KNTEzf1bpc6Tbyy5ZvoYP8H/TI9ts= github.com/hashicorp/vault-plugin-database-mongodbatlas v0.1.1/go.mod h1:MP3kfr0N+7miOTZFwKv952b9VkXM4S2Q6YtQCiNKWq8= -github.com/hashicorp/vault-plugin-secrets-ad v0.6.5 h1:wrHzXSD6qmKvkuHaQn+BNj89+HGhMNchxAckGnd7YTc= -github.com/hashicorp/vault-plugin-secrets-ad v0.6.5/go.mod h1:kk98nB+cwDbt3I7UGQq3ota7+eHZrGSTQZfSRGpluvA= +github.com/hashicorp/vault-plugin-secrets-ad v0.6.4-beta1.0.20200518124111-3dceeb3ce90e/go.mod h1:SCsKcChP8yrtOHXOeTD7oRk0oflj3IxA9y9zTOGtQ8s= github.com/hashicorp/vault-plugin-secrets-alicloud v0.5.5 h1:BOOtSls+BQ1EtPmpE9LoqZztsEZ1fRWVSkHWtRIrCB4= github.com/hashicorp/vault-plugin-secrets-alicloud v0.5.5/go.mod h1:gAoReoUpBHaBwkxQqTK7FY8nQC0MuaZHLiW5WOSny5g= github.com/hashicorp/vault-plugin-secrets-azure v0.5.6 h1:4PgQ5rCT29wW5PMyebEhPkEYuR5s+SnInuZz3x2cP50= @@ -1393,7 +1391,6 @@ golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190730183949-1393eb018365/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 7ed919b1a0..ade410ca0c 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -40,6 +40,8 @@ listener "tcp" { tls_disable = true proxy_protocol_behavior = "allow_authorized" proxy_protocol_authorized_addrs = "127.0.0.1" + cors_enabled = true + cors_allowed_origins = ["*"] } listener "tcp" { diff --git a/internal/cmd/config/config_test.go b/internal/cmd/config/config_test.go new file mode 100644 index 0000000000..02a2405b37 --- /dev/null +++ b/internal/cmd/config/config_test.go @@ -0,0 +1,78 @@ +package config + +import ( + "testing" + "time" + + "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/stretchr/testify/assert" +) + +func TestDevController(t *testing.T) { + actual, err := DevController() + if err != nil { + t.Fatal(err) + } + + addr, err := sockaddr.NewIPAddr("127.0.0.1") + if err != nil { + t.Fatal(err) + } + + exp := &Config{ + SharedConfig: &configutil.SharedConfig{ + DisableMlock: true, + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Purpose: []string{"api"}, + TLSDisable: true, + ProxyProtocolBehavior: "allow_authorized", + ProxyProtocolAuthorizedAddrs: []*sockaddr.SockAddrMarshaler{ + {SockAddr: addr}, + }, + CorsEnabled: true, + CorsAllowedOrigins: []string{"*"}, + }, + { + Type: "tcp", + Purpose: []string{"cluster"}, + TLSDisable: true, + ProxyProtocolBehavior: "allow_authorized", + ProxyProtocolAuthorizedAddrs: []*sockaddr.SockAddrMarshaler{ + {SockAddr: addr}, + }, + }, + }, + Seals: []*configutil.KMS{ + { + Type: "aead", + Purpose: []string{"controller"}, + Config: map[string]string{ + "aead_type": "aes-gcm", + }, + }, + { + Type: "aead", + Purpose: []string{"worker-auth"}, + Config: map[string]string{ + "aead_type": "aes-gcm", + }, + }, + }, + Telemetry: &configutil.Telemetry{ + DisableHostname: true, + PrometheusRetentionTime: time.Hour * 24, + }, + }, + DevController: true, + } + + exp.Listeners[0].RawConfig = actual.Listeners[0].RawConfig + exp.Listeners[1].RawConfig = actual.Listeners[1].RawConfig + exp.Seals[0].Config["key"] = actual.Seals[0].Config["key"] + exp.Seals[1].Config["key"] = actual.Seals[1].Config["key"] + + assert.Equal(t, exp, actual) +} diff --git a/internal/servers/controller/cors_test.go b/internal/servers/controller/cors_test.go new file mode 100644 index 0000000000..7265d17527 --- /dev/null +++ b/internal/servers/controller/cors_test.go @@ -0,0 +1,220 @@ +package controller + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/hashicorp/watchtower/internal/cmd/config" +) + +const corsTestConfig = ` +disable_mlock = true + +telemetry { + prometheus_retention_time = "24h" + disable_hostname = true +} + +kms "aead" { + purpose = "controller" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" +} + +kms "aead" { + purpose = "worker-auth" + aead_type = "aes-gcm" + key = "09iqFxRJNYsl/b8CQxjnGw==" +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = false +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = [] +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["foobar.com", "barfoo.com"] +} + +listener "tcp" { + purpose = "api" + tls_disable = true + cors_enabled = true + cors_allowed_origins = ["*"] + cors_allowed_headers = ["x-foobar"] +} +` + +func TestHandler_CORS(t *testing.T) { + cfg, err := config.Parse(corsTestConfig) + if err != nil { + t.Fatal(err) + } + tc := NewTestController(t, &TestControllerOpts{ + Config: cfg, + DefaultOrgId: "o_1234567890", + }) + defer tc.Shutdown() + + cases := []struct { + name string + method string + origin string + code int + acrmHeader string + allowedHeader string + listenerNum int + }{ + { + "disabled no origin", + http.MethodPost, + "", + http.StatusOK, + "", + "", + 1, + }, + { + "disabled with origin", + http.MethodPost, + "foobar.com", + http.StatusOK, + "", + "", + 1, + }, + { + "enabled with no allowed origins and no origin defined", + http.MethodPost, + "", + http.StatusOK, + "", + "", + 2, + }, + { + "enabled with no allowed origins and origin defined", + http.MethodPost, + "foobar.com", + http.StatusForbidden, + "", + "", + 2, + }, + { + "enabled with allowed origins and no origin defined", + http.MethodPost, + "", + http.StatusOK, + "", + "", + 3, + }, + { + "enabled with allowed origins and bad origin defined", + http.MethodPost, + "flubber.com", + http.StatusForbidden, + "", + "", + 3, + }, + { + "enabled with allowed origins and good origin defined", + http.MethodPost, + "barfoo.com", + http.StatusOK, + "", + "", + 3, + }, + { + "enabled with wildcard origins and no origin defined", + http.MethodPost, + "", + http.StatusOK, + "", + "", + 4, + }, + { + "enabled with wildcard origins and origin defined", + http.MethodPost, + "flubber.com", + http.StatusOK, + "", + "", + 4, + }, + { + "wildcard origins with method list and good method", + http.MethodOptions, + "flubber.com", + http.StatusNoContent, + "DELETE", + "", + 4, + }, + { + "wildcard origins with method list and bad method", + http.MethodOptions, + "flubber.com", + http.StatusMethodNotAllowed, + "BADSTUFF", + "X-Foobar", + 4, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + // Create a client with the right address + client := tc.Client() + client.SetAddr(tc.ApiAddrs()[c.listenerNum-1]) + + // Create the request + req, err := client.NewRequest(tc.Context(), c.method, "orgs/o_1234567890/projects", nil) + assert.NoError(t, err) + + // Append headers + if c.origin != "" { + req.Header.Add("Origin", c.origin) + } + if c.acrmHeader != "" { + req.Header.Add("Access-Control-Request-Method", c.acrmHeader) + } + + // Run the request, do basic checks + resp, err := client.Do(req) + assert.NoError(t, err) + assert.Equal(t, c.code, resp.HttpResponse().StatusCode) + + // If options and we expect it to be successful, run some checks + if req.Method == http.MethodOptions && c.code == http.StatusNoContent { + assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s, %s", http.MethodDelete, http.MethodGet, http.MethodOptions, http.MethodPost, http.MethodPatch), resp.HttpResponse().Header.Get("Access-Control-Allow-Methods")) + assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s", "Content-Type", "X-Requested-With", "Authorization", "X-Foobar"), resp.HttpResponse().Header.Get("Access-Control-Allow-Headers")) + assert.Equal(t, "300", resp.HttpResponse().Header.Get("Access-Control-Max-Age")) + } + + // If origin was set and we expect it to be successful, run some more checks + if c.origin != "" && c.code == http.StatusOK && c.listenerNum > 1 { + assert.Equal(t, c.origin, resp.HttpResponse().Header.Get("Access-Control-Allow-Origin")) + assert.Equal(t, "Origin", resp.HttpResponse().Header.Get("Vary")) + } + }) + } +} diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index 73f0d04d8d..9511ed9b55 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -2,6 +2,7 @@ package controller import ( "context" + "encoding/json" "net/http" "path" "strings" @@ -9,6 +10,8 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/runtime" "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/helper/strutil" + "github.com/hashicorp/watchtower/api" "github.com/hashicorp/watchtower/globals" "github.com/hashicorp/watchtower/internal/gen/controller/api/services" "github.com/hashicorp/watchtower/internal/servers/controller/handlers/host_catalogs" @@ -29,9 +32,10 @@ func (c *Controller) handler(props HandlerProperties) http.Handler { mux.Handle("/v1/", handleGrpcGateway(c)) - genericWrappedHandler := wrapGenericHandler(mux, c, props) + corsWrappedHandler := wrapHandlerWithCors(mux, props) + commonWrappedHandler := wrapHandlerWithCommonFuncs(corsWrappedHandler, c, props) - return genericWrappedHandler + return commonWrappedHandler } func handleGrpcGateway(c *Controller) http.Handler { @@ -47,7 +51,7 @@ func handleGrpcGateway(c *Controller) http.Handler { return mux } -func wrapGenericHandler(h http.Handler, c *Controller, props HandlerProperties) http.Handler { +func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler { var maxRequestDuration time.Duration var maxRequestSize int64 if props.ListenerConfig != nil { @@ -94,6 +98,86 @@ func wrapGenericHandler(h http.Handler, c *Controller, props HandlerProperties) }) } +func wrapHandlerWithCors(h http.Handler, props HandlerProperties) http.Handler { + allowedMethods := []string{ + http.MethodDelete, + http.MethodGet, + http.MethodOptions, + http.MethodPost, + http.MethodPatch, + } + + allowedOrigins := props.ListenerConfig.CorsAllowedOrigins + + allowedHeaders := append([]string{ + "Content-Type", + "X-Requested-With", + "Authorization", + }, props.ListenerConfig.CorsAllowedHeaders...) + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !props.ListenerConfig.CorsEnabled { + h.ServeHTTP(w, req) + return + } + + origin := req.Header.Get("Origin") + + if origin == "" { + // Serve directly + h.ServeHTTP(w, req) + return + } + + // Check origin + var valid bool + switch { + case len(allowedOrigins) == 0: + // not valid + + case len(allowedOrigins) == 1 && allowedOrigins[0] == "*": + valid = true + + default: + valid = strutil.StrListContains(allowedOrigins, origin) + } + if !valid { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + + err := &api.Error{ + Status: api.Int(http.StatusForbidden), + Code: api.String("origin forbidden"), + } + + enc := json.NewEncoder(w) + enc.Encode(err) + return + } + + if req.Method == http.MethodOptions && + !strutil.StrListContains(allowedMethods, req.Header.Get("Access-Control-Request-Method")) { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + + // Apply headers for preflight requests + if req.Method == http.MethodOptions { + w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", ")) + w.Header().Set("Access-Control-Max-Age", "300") + w.WriteHeader(http.StatusNoContent) + return + } + + h.ServeHTTP(w, req) + return + }) +} + /* func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/servers/controller/handler_test.go b/internal/servers/controller/handler_test.go index 1cdb0dbb24..323fe3eef9 100644 --- a/internal/servers/controller/handler_test.go +++ b/internal/servers/controller/handler_test.go @@ -48,3 +48,4 @@ func TestHandleGrpcGateway(t *testing.T) { }) } } + diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 894bdd5e7f..b73c0b1d83 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -19,7 +19,7 @@ type TestController struct { b *base.Server c *Controller t *testing.T - addr string // The address the Controller API is listening on + addrs []string // The address the Controller API is listening on client *api.Client ctx context.Context cancel context.CancelFunc @@ -42,27 +42,23 @@ func (tc *TestController) Cancel() { tc.cancel() } -func (tc *TestController) ApiAddress() string { - if tc.addr != "" { - return tc.addr +func (tc *TestController) ApiAddrs() []string { + if tc.addrs != nil { + return tc.addrs } - var apiLn *base.ServerListener + for _, listener := range tc.b.Listeners { if listener.Config.Purpose[0] == "api" { - apiLn = listener - break + tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr) + if !ok { + tc.t.Fatal("could not parse address as a TCP addr") + } + addr := fmt.Sprintf("http://%s:%d", tcpAddr.IP.String(), tcpAddr.Port) + tc.addrs = append(tc.addrs, addr) } } - if apiLn == nil { - tc.t.Fatal("could not find api listener") - } - tcpAddr, ok := apiLn.Mux.Addr().(*net.TCPAddr) - if !ok { - tc.t.Fatal("could not parse address as a TCP addr") - } - tc.addr = fmt.Sprintf("http://%s:%d", tcpAddr.IP.String(), tcpAddr.Port) - return tc.addr + return tc.addrs } func (tc *TestController) buildClient() { @@ -70,7 +66,11 @@ func (tc *TestController) buildClient() { if err != nil { tc.t.Fatal(fmt.Errorf("error creating client: %w", err)) } - if err := client.SetAddr(tc.ApiAddress()); err != nil { + apiAddrs := tc.ApiAddrs() + if len(apiAddrs) == 0 { + tc.t.Fatal("no API addresses found") + } + if err := client.SetAddr(apiAddrs[0]); err != nil { tc.t.Fatal(fmt.Errorf("error setting client address: %w", err)) } diff --git a/testing/controller/controller.go b/testing/controller/controller.go index 6bfdc8a30c..43f16b965b 100644 --- a/testing/controller/controller.go +++ b/testing/controller/controller.go @@ -91,11 +91,15 @@ func WithDefaultOrgId(id string) Option { // NewTestController blocks until a new TestController is created, returns the url for the TestController and a function // that can be called to tear down the controller after it has been used for testing. -func NewTestController(t *testing.T, opt ...Option) (string, func()) { +func NewTestController(t *testing.T, opt ...Option) *TestController { conf, err := getOpts(opt...) if err != nil { t.Fatalf("Couldn't create TestController: %v", err) } tc := controller.NewTestController(t, conf) - return tc.ApiAddress(), tc.Shutdown + return &TestController{TestController: tc} +} + +type TestController struct { + *controller.TestController }