You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
boundary/internal/servers/controller/cors_test.go

221 lines
4.6 KiB

package controller
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/hashicorp/watchtower/internal/cmd/config"
)
// corsTestConfig is the controller configuration used by TestHandler_CORS.
// It declares four "api" listeners, each with TLS disabled:
//  1. CORS disabled.
//  2. CORS enabled with an empty cors_allowed_origins list.
//  3. CORS enabled, allowing only "foobar.com" and "barfoo.com".
//  4. CORS enabled with a wildcard origin and an extra allowed header "x-foobar".
//
// TestHandler_CORS picks a listener by its 1-based position in this list via
// tc.ApiAddrs(). NOTE(review): this assumes ApiAddrs preserves declaration
// order. Confirm before reordering these blocks.
const corsTestConfig = `
disable_mlock = true
telemetry {
prometheus_retention_time = "24h"
disable_hostname = true
}
kms "aead" {
purpose = "controller"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
}
kms "aead" {
purpose = "worker-auth"
aead_type = "aes-gcm"
key = "09iqFxRJNYsl/b8CQxjnGw=="
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = false
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = []
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["foobar.com", "barfoo.com"]
}
listener "tcp" {
purpose = "api"
tls_disable = true
cors_enabled = true
cors_allowed_origins = ["*"]
cors_allowed_headers = ["x-foobar"]
}
`
// TestHandler_CORS starts a test controller from corsTestConfig and sends one
// request per case to a chosen API listener. Each request carries an optional
// Origin and Access-Control-Request-Method header.
//
// Every case checks the response status code. Two kinds of success also
// check CORS headers:
//   - Preflight: an OPTIONS request answered with 204 must advertise the
//     expected Access-Control-Allow-Methods/-Headers and Max-Age values.
//   - Cross-origin: a 200 response on a CORS-enabled listener (listener 2 or
//     above) must echo the request Origin and set "Vary: Origin".
func TestHandler_CORS(t *testing.T) {
	cfg, err := config.Parse(corsTestConfig)
	if err != nil {
		t.Fatal(err)
	}
	tc := NewTestController(t, &TestControllerOpts{
		Config:       cfg,
		DefaultOrgId: "o_1234567890",
	})
	defer tc.Shutdown()

	cases := []struct {
		name   string
		method string
		// origin, when non-empty, is sent as the Origin request header.
		origin string
		// code is the expected HTTP status code.
		code int
		// acrmHeader, when non-empty, is sent as Access-Control-Request-Method.
		acrmHeader string
		// NOTE(review): allowedHeader is not read by any assertion below. The
		// expected Access-Control-Allow-Headers value is hard-coded instead.
		allowedHeader string
		// listenerNum is the 1-based index into tc.ApiAddrs() of the listener
		// that receives the request.
		listenerNum int
	}{
		{
			name:        "disabled no origin",
			method:      http.MethodPost,
			code:        http.StatusOK,
			listenerNum: 1,
		},
		{
			name:        "disabled with origin",
			method:      http.MethodPost,
			origin:      "foobar.com",
			code:        http.StatusOK,
			listenerNum: 1,
		},
		{
			name:        "enabled with no allowed origins and no origin defined",
			method:      http.MethodPost,
			code:        http.StatusOK,
			listenerNum: 2,
		},
		{
			name:        "enabled with no allowed origins and origin defined",
			method:      http.MethodPost,
			origin:      "foobar.com",
			code:        http.StatusForbidden,
			listenerNum: 2,
		},
		{
			name:        "enabled with allowed origins and no origin defined",
			method:      http.MethodPost,
			code:        http.StatusOK,
			listenerNum: 3,
		},
		{
			name:        "enabled with allowed origins and bad origin defined",
			method:      http.MethodPost,
			origin:      "flubber.com",
			code:        http.StatusForbidden,
			listenerNum: 3,
		},
		{
			name:        "enabled with allowed origins and good origin defined",
			method:      http.MethodPost,
			origin:      "barfoo.com",
			code:        http.StatusOK,
			listenerNum: 3,
		},
		{
			name:        "enabled with wildcard origins and no origin defined",
			method:      http.MethodPost,
			code:        http.StatusOK,
			listenerNum: 4,
		},
		{
			name:        "enabled with wildcard origins and origin defined",
			method:      http.MethodPost,
			origin:      "flubber.com",
			code:        http.StatusOK,
			listenerNum: 4,
		},
		{
			name:        "wildcard origins with method list and good method",
			method:      http.MethodOptions,
			origin:      "flubber.com",
			code:        http.StatusNoContent,
			acrmHeader:  "DELETE",
			listenerNum: 4,
		},
		{
			name:          "wildcard origins with method list and bad method",
			method:        http.MethodOptions,
			origin:        "flubber.com",
			code:          http.StatusMethodNotAllowed,
			acrmHeader:    "BADSTUFF",
			allowedHeader: "X-Foobar",
			listenerNum:   4,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Create a client pointed at the listener under test.
			client := tc.Client()
			client.SetAddr(tc.ApiAddrs()[c.listenerNum-1])

			req, err := client.NewRequest(tc.Context(), c.method, "orgs/o_1234567890/projects", nil)
			// Stop this subtest on error: req is dereferenced below and a nil
			// value would panic instead of reporting a clean failure.
			if !assert.NoError(t, err) {
				return
			}
			if c.origin != "" {
				req.Header.Add("Origin", c.origin)
			}
			if c.acrmHeader != "" {
				req.Header.Add("Access-Control-Request-Method", c.acrmHeader)
			}

			resp, err := client.Do(req)
			// Same reasoning as above: resp is dereferenced below.
			if !assert.NoError(t, err) {
				return
			}
			httpResp := resp.HttpResponse()
			assert.Equal(t, c.code, httpResp.StatusCode)

			// Successful preflight: check the advertised methods, headers and max age.
			if req.Method == http.MethodOptions && c.code == http.StatusNoContent {
				assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s, %s", http.MethodDelete, http.MethodGet, http.MethodOptions, http.MethodPost, http.MethodPatch), httpResp.Header.Get("Access-Control-Allow-Methods"))
				assert.Equal(t, fmt.Sprintf("%s, %s, %s, %s", "Content-Type", "X-Requested-With", "Authorization", "X-Foobar"), httpResp.Header.Get("Access-Control-Allow-Headers"))
				assert.Equal(t, "300", httpResp.Header.Get("Access-Control-Max-Age"))
			}

			// Successful cross-origin request on a CORS-enabled listener
			// (listener 1 has CORS disabled): the origin must be echoed back.
			if c.origin != "" && c.code == http.StatusOK && c.listenerNum > 1 {
				assert.Equal(t, c.origin, httpResp.Header.Get("Access-Control-Allow-Origin"))
				assert.Equal(t, "Origin", httpResp.Header.Get("Vary"))
			}
		})
	}
}