CORS support (#73)

This adds per-listener CORS support. Logic is ported from Vault, along with a fully new test suite.
pull/76/head
Jeff Mitchell 6 years ago committed by GitHub
parent 66978d637a
commit fafdcbbfa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,7 +27,7 @@ require (
github.com/hashicorp/go-sockaddr v1.0.2
github.com/hashicorp/go-uuid v1.0.2
github.com/hashicorp/hcl v1.0.0
github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641
github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19
github.com/hashicorp/vault/sdk v0.1.14-0.20200519221912-a8c2591d3641
github.com/jackc/pgx/v4 v4.6.0
github.com/jinzhu/gorm v1.9.12

@ -562,13 +562,12 @@ github.com/grpc-ecosystem/grpc-gateway v1.14.5 h1:aiLxiiVzAXb7wb3lAmubA69IokWOoU
github.com/grpc-ecosystem/grpc-gateway v1.14.5/go.mod h1:UJ0EZAp832vCd54Wev9N1BMKEyvcZ5+IM0AwDrnlkEc=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/consul-template v0.22.0 h1:ti5cqAekOeMfFYLJCjlPtKGwBcqwVxoZO/Y2vctwuUE=
github.com/hashicorp/consul-template v0.22.0/go.mod h1:lHrykBIcPobCuEcIMLJryKxDyk2lUMnQWmffOEONH0k=
github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q=
github.com/hashicorp/consul/api v1.2.1-0.20200128105449-6681be918a6e h1:vOqdnsq53winzJDN6RTQe9n9g87S595PNsdwKyBWXRM=
github.com/hashicorp/consul/api v1.2.1-0.20200128105449-6681be918a6e/go.mod h1:ztzLK20HA5O27oTf2j/wbNgq8qj/crN8xsSx7pzX0sc=
github.com/hashicorp/consul-template v0.25.0 h1:wsnv4jSqBIVzlg6U0wNg+ePzfrsF3Vi9MqIqDEUrg9U=
github.com/hashicorp/consul-template v0.25.0/go.mod h1:/vUsrJvDuuQHcxEw0zik+YXTS7ZKWZjQeaQhshBmfH0=
github.com/hashicorp/consul/api v1.4.0 h1:jfESivXnO5uLdH650JU/6AnjRoHrLhULq0FnC3Kp9EY=
github.com/hashicorp/consul/api v1.4.0/go.mod h1:xc8u05kyMa3Wjr9eEAsIAo3dg8+LywT5E/Cl7cNS5nU=
github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8=
github.com/hashicorp/consul/sdk v0.2.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8=
github.com/hashicorp/consul/sdk v0.4.0/go.mod h1:fY08Y9z5SvJqevyZNy6WWPXiG3KwBPAvlcdx16zZ0fM=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-alpnmux v0.0.0-20200513011953-0293f5d23c31 h1:pxqI71/0R1WIASjQEJ9W9skCKYiREEkRoXFvHCZH1pg=
@ -657,8 +656,8 @@ github.com/hashicorp/raft-snapshot v1.0.2-0.20190827162939-8117efcc5aab/go.mod h
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/hashicorp/serf v0.8.3 h1:MWYcmct5EtKz0efYooPcL0yNkem+7kWxqXDi/UIh+8k=
github.com/hashicorp/serf v0.8.3/go.mod h1:UpNcs7fFbpKIyZaUuSW6EPiH+eZC7OuyFD+wc1oal+k=
github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641 h1:sFzGh4bThye9dcV3XhO6xOtWlhXY73cgfod/B3eqdf0=
github.com/hashicorp/vault v1.2.1-0.20200519221912-a8c2591d3641/go.mod h1:AxZRQLEuX4UQt4GIITGJE/s/dpQJjwQnPfufvchqxFM=
github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19 h1:UnyS8rJUziC7TLs/0blpiCpbnq3l3fwzkpZG+VLLUWI=
github.com/hashicorp/vault v1.2.1-0.20200521015612-812a92b26b19/go.mod h1:m42VsHQcRRfOQWFUF7eDOjYE18W+ZH2KgQzJezt7YKo=
github.com/hashicorp/vault-plugin-auth-alicloud v0.5.5 h1:JYf3VYpKs7mOdtcwZWi73S82oXrC/JR7uoPVUd8c4Hk=
github.com/hashicorp/vault-plugin-auth-alicloud v0.5.5/go.mod h1:sQ+VNwPQlemgXHXikYH6onfH9gPwDZ1GUVRLz0ZvHx8=
github.com/hashicorp/vault-plugin-auth-azure v0.5.5 h1:kN79ai+aMVU9hUmwscHjmweW2fGa8V/t+ScIchPZGrk=
@ -682,8 +681,7 @@ github.com/hashicorp/vault-plugin-database-elasticsearch v0.5.4 h1:YE4qndazWmYGp
github.com/hashicorp/vault-plugin-database-elasticsearch v0.5.4/go.mod h1:QjGrrxcRXv/4XkEZAlM0VMZEa3uxKAICFqDj27FP/48=
github.com/hashicorp/vault-plugin-database-mongodbatlas v0.1.1 h1:fA6cFH8lIPH2M4KNTEzf1bpc6Tbyy5ZvoYP8H/TI9ts=
github.com/hashicorp/vault-plugin-database-mongodbatlas v0.1.1/go.mod h1:MP3kfr0N+7miOTZFwKv952b9VkXM4S2Q6YtQCiNKWq8=
github.com/hashicorp/vault-plugin-secrets-ad v0.6.5 h1:wrHzXSD6qmKvkuHaQn+BNj89+HGhMNchxAckGnd7YTc=
github.com/hashicorp/vault-plugin-secrets-ad v0.6.5/go.mod h1:kk98nB+cwDbt3I7UGQq3ota7+eHZrGSTQZfSRGpluvA=
github.com/hashicorp/vault-plugin-secrets-ad v0.6.4-beta1.0.20200518124111-3dceeb3ce90e/go.mod h1:SCsKcChP8yrtOHXOeTD7oRk0oflj3IxA9y9zTOGtQ8s=
github.com/hashicorp/vault-plugin-secrets-alicloud v0.5.5 h1:BOOtSls+BQ1EtPmpE9LoqZztsEZ1fRWVSkHWtRIrCB4=
github.com/hashicorp/vault-plugin-secrets-alicloud v0.5.5/go.mod h1:gAoReoUpBHaBwkxQqTK7FY8nQC0MuaZHLiW5WOSny5g=
github.com/hashicorp/vault-plugin-secrets-azure v0.5.6 h1:4PgQ5rCT29wW5PMyebEhPkEYuR5s+SnInuZz3x2cP50=
@ -1393,7 +1391,6 @@ golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190730183949-1393eb018365/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

@ -40,6 +40,8 @@ listener "tcp" {
tls_disable = true
proxy_protocol_behavior = "allow_authorized"
proxy_protocol_authorized_addrs = "127.0.0.1"
cors_enabled = true
cors_allowed_origins = ["*"]
}
listener "tcp" {

@ -0,0 +1,78 @@
package config
import (
"testing"
"time"
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/stretchr/testify/assert"
)
func TestDevController(t *testing.T) {
actual, err := DevController()
if err != nil {
t.Fatal(err)
}
addr, err := sockaddr.NewIPAddr("127.0.0.1")
if err != nil {
t.Fatal(err)
}
exp := &Config{
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
Listeners: []*configutil.Listener{
{
Type: "tcp",
Purpose: []string{"api"},
TLSDisable: true,
ProxyProtocolBehavior: "allow_authorized",
ProxyProtocolAuthorizedAddrs: []*sockaddr.SockAddrMarshaler{
{SockAddr: addr},
},
CorsEnabled: true,
CorsAllowedOrigins: []string{"*"},
},
{
Type: "tcp",
Purpose: []string{"cluster"},
TLSDisable: true,
ProxyProtocolBehavior: "allow_authorized",
ProxyProtocolAuthorizedAddrs: []*sockaddr.SockAddrMarshaler{
{SockAddr: addr},
},
},
},
Seals: []*configutil.KMS{
{
Type: "aead",
Purpose: []string{"controller"},
Config: map[string]string{
"aead_type": "aes-gcm",
},
},
{
Type: "aead",
Purpose: []string{"worker-auth"},
Config: map[string]string{
"aead_type": "aes-gcm",
},
},
},
Telemetry: &configutil.Telemetry{
DisableHostname: true,
PrometheusRetentionTime: time.Hour * 24,
},
},
DevController: true,
}
exp.Listeners[0].RawConfig = actual.Listeners[0].RawConfig
exp.Listeners[1].RawConfig = actual.Listeners[1].RawConfig
exp.Seals[0].Config["key"] = actual.Seals[0].Config["key"]
exp.Seals[1].Config["key"] = actual.Seals[1].Config["key"]
assert.Equal(t, exp, actual)
}

@ -0,0 +1,220 @@
package controller
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/hashicorp/watchtower/internal/cmd/config"
)
const corsTestConfig = `
disable_mlock = true
telemetry {
prometheus_retention_time = "24h"
disable_hostname = true
}
kms "aead" {
purpose = "controller"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
}
kms "aead" {
purpose = "worker-auth"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = false
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = []
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["foobar.com", "barfoo.com"]
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["*"]
cors_allowed_headers = ["x-foobar"]
}
`
func TestHandler_CORS(t *testing.T) {
cfg, err := config.Parse(corsTestConfig)
if err != nil {
t.Fatal(err)
}
tc := NewTestController(t, &TestControllerOpts{
Config: cfg,
DefaultOrgId: "o_1234567890",
})
defer tc.Shutdown()
cases := []struct {
name string
method string
origin string
code int
acrmHeader string
allowedHeader string
listenerNum int
}{
{
"disabled no origin",
http.MethodPost,
"",
http.StatusOK,
"",
"",
1,
},
{
"disabled with origin",
http.MethodPost,
"foobar.com",
http.StatusOK,
"",
"",
1,
},
{
"enabled with no allowed origins and no origin defined",
http.MethodPost,
"",
http.StatusOK,
"",
"",
2,
},
{
"enabled with no allowed origins and origin defined",
http.MethodPost,
"foobar.com",
http.StatusForbidden,
"",
"",
2,
},
{
"enabled with allowed origins and no origin defined",
http.MethodPost,
"",
http.StatusOK,
"",
"",
3,
},
{
"enabled with allowed origins and bad origin defined",
http.MethodPost,
"flubber.com",
http.StatusForbidden,
"",
"",
3,
},
{
"enabled with allowed origins and good origin defined",
http.MethodPost,
"barfoo.com",
http.StatusOK,
"",
"",
3,
},
{
"enabled with wildcard origins and no origin defined",
http.MethodPost,
"",
http.StatusOK,
"",
"",
4,
},
{
"enabled with wildcard origins and origin defined",
http.MethodPost,
"flubber.com",
http.StatusOK,
"",
"",
4,
},
{
"wildcard origins with method list and good method",
http.MethodOptions,
"flubber.com",
http.StatusNoContent,
"DELETE",
"",
4,
},
{
"wildcard origins with method list and bad method",
http.MethodOptions,
"flubber.com",
http.StatusMethodNotAllowed,
"BADSTUFF",
"X-Foobar",
4,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
// Create a client with the right address
client := tc.Client()
client.SetAddr(tc.ApiAddrs()[c.listenerNum-1])
// Create the request
req, err := client.NewRequest(tc.Context(), c.method, "orgs/o_1234567890/projects", nil)
assert.NoError(t, err)
// Append headers
if c.origin != "" {
req.Header.Add("Origin", c.origin)
}
if c.acrmHeader != "" {
req.Header.Add("Access-Control-Request-Method", c.acrmHeader)
}
// Run the request, do basic checks
resp, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, c.code, resp.HttpResponse().StatusCode)
// If options and we expect it to be successful, run some checks
if req.Method == http.MethodOptions && c.code == http.StatusNoContent {
assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s, %s", http.MethodDelete, http.MethodGet, http.MethodOptions, http.MethodPost, http.MethodPatch), resp.HttpResponse().Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s", "Content-Type", "X-Requested-With", "Authorization", "X-Foobar"), resp.HttpResponse().Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "300", resp.HttpResponse().Header.Get("Access-Control-Max-Age"))
}
// If origin was set and we expect it to be successful, run some more checks
if c.origin != "" && c.code == http.StatusOK && c.listenerNum > 1 {
assert.Equal(t, c.origin, resp.HttpResponse().Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "Origin", resp.HttpResponse().Header.Get("Vary"))
}
})
}
}

@ -2,6 +2,7 @@ package controller
import (
"context"
"encoding/json"
"net/http"
"path"
"strings"
@ -9,6 +10,8 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/watchtower/api"
"github.com/hashicorp/watchtower/globals"
"github.com/hashicorp/watchtower/internal/gen/controller/api/services"
"github.com/hashicorp/watchtower/internal/servers/controller/handlers/host_catalogs"
@ -29,9 +32,10 @@ func (c *Controller) handler(props HandlerProperties) http.Handler {
mux.Handle("/v1/", handleGrpcGateway(c))
genericWrappedHandler := wrapGenericHandler(mux, c, props)
corsWrappedHandler := wrapHandlerWithCors(mux, props)
commonWrappedHandler := wrapHandlerWithCommonFuncs(corsWrappedHandler, c, props)
return genericWrappedHandler
return commonWrappedHandler
}
func handleGrpcGateway(c *Controller) http.Handler {
@ -47,7 +51,7 @@ func handleGrpcGateway(c *Controller) http.Handler {
return mux
}
func wrapGenericHandler(h http.Handler, c *Controller, props HandlerProperties) http.Handler {
func wrapHandlerWithCommonFuncs(h http.Handler, c *Controller, props HandlerProperties) http.Handler {
var maxRequestDuration time.Duration
var maxRequestSize int64
if props.ListenerConfig != nil {
@ -94,6 +98,86 @@ func wrapGenericHandler(h http.Handler, c *Controller, props HandlerProperties)
})
}
func wrapHandlerWithCors(h http.Handler, props HandlerProperties) http.Handler {
allowedMethods := []string{
http.MethodDelete,
http.MethodGet,
http.MethodOptions,
http.MethodPost,
http.MethodPatch,
}
allowedOrigins := props.ListenerConfig.CorsAllowedOrigins
allowedHeaders := append([]string{
"Content-Type",
"X-Requested-With",
"Authorization",
}, props.ListenerConfig.CorsAllowedHeaders...)
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !props.ListenerConfig.CorsEnabled {
h.ServeHTTP(w, req)
return
}
origin := req.Header.Get("Origin")
if origin == "" {
// Serve directly
h.ServeHTTP(w, req)
return
}
// Check origin
var valid bool
switch {
case len(allowedOrigins) == 0:
// not valid
case len(allowedOrigins) == 1 && allowedOrigins[0] == "*":
valid = true
default:
valid = strutil.StrListContains(allowedOrigins, origin)
}
if !valid {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
err := &api.Error{
Status: api.Int(http.StatusForbidden),
Code: api.String("origin forbidden"),
}
enc := json.NewEncoder(w)
enc.Encode(err)
return
}
if req.Method == http.MethodOptions &&
!strutil.StrListContains(allowedMethods, req.Header.Get("Access-Control-Request-Method")) {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
// Apply headers for preflight requests
if req.Method == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(allowedHeaders, ", "))
w.Header().Set("Access-Control-Max-Age", "300")
w.WriteHeader(http.StatusNoContent)
return
}
h.ServeHTTP(w, req)
return
})
}
/*
func WrapForwardedForHandler(h http.Handler, authorizedAddrs []*sockaddr.SockAddrMarshaler, rejectNotPresent, rejectNonAuthz bool, hopSkips int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

@ -48,3 +48,4 @@ func TestHandleGrpcGateway(t *testing.T) {
})
}
}

@ -19,7 +19,7 @@ type TestController struct {
b *base.Server
c *Controller
t *testing.T
addr string // The address the Controller API is listening on
addrs []string // The address the Controller API is listening on
client *api.Client
ctx context.Context
cancel context.CancelFunc
@ -42,27 +42,23 @@ func (tc *TestController) Cancel() {
tc.cancel()
}
func (tc *TestController) ApiAddress() string {
if tc.addr != "" {
return tc.addr
func (tc *TestController) ApiAddrs() []string {
if tc.addrs != nil {
return tc.addrs
}
var apiLn *base.ServerListener
for _, listener := range tc.b.Listeners {
if listener.Config.Purpose[0] == "api" {
apiLn = listener
break
tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
if !ok {
tc.t.Fatal("could not parse address as a TCP addr")
}
addr := fmt.Sprintf("http://%s:%d", tcpAddr.IP.String(), tcpAddr.Port)
tc.addrs = append(tc.addrs, addr)
}
}
if apiLn == nil {
tc.t.Fatal("could not find api listener")
}
tcpAddr, ok := apiLn.Mux.Addr().(*net.TCPAddr)
if !ok {
tc.t.Fatal("could not parse address as a TCP addr")
}
tc.addr = fmt.Sprintf("http://%s:%d", tcpAddr.IP.String(), tcpAddr.Port)
return tc.addr
return tc.addrs
}
func (tc *TestController) buildClient() {
@ -70,7 +66,11 @@ func (tc *TestController) buildClient() {
if err != nil {
tc.t.Fatal(fmt.Errorf("error creating client: %w", err))
}
if err := client.SetAddr(tc.ApiAddress()); err != nil {
apiAddrs := tc.ApiAddrs()
if len(apiAddrs) == 0 {
tc.t.Fatal("no API addresses found")
}
if err := client.SetAddr(apiAddrs[0]); err != nil {
tc.t.Fatal(fmt.Errorf("error setting client address: %w", err))
}

@ -91,11 +91,15 @@ func WithDefaultOrgId(id string) Option {
// NewTestController blocks until a new TestController is created, returns the url for the TestController and a function
// that can be called to tear down the controller after it has been used for testing.
func NewTestController(t *testing.T, opt ...Option) (string, func()) {
func NewTestController(t *testing.T, opt ...Option) *TestController {
conf, err := getOpts(opt...)
if err != nil {
t.Fatalf("Couldn't create TestController: %v", err)
}
tc := controller.NewTestController(t, conf)
return tc.ApiAddress(), tc.Shutdown
return &TestController{TestController: tc}
}
type TestController struct {
*controller.TestController
}

Loading…
Cancel
Save