You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
boundary/internal/daemon/common/handler_test.go

590 lines
17 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package common
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/eventlogger/filters/gated"
"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-sockaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_WrapWithOptionals verifies that WrapWithOptionals rejects a missing
// writer wrapper or response writer. For each combination of optional
// interfaces implemented by the wrapped http.ResponseWriter, it also checks
// that the returned writer always exposes StatusCode(). The returned writer
// must expose exactly those of http.Flusher, http.Pusher and http.Hijacker
// that the wrapped writer implements: none dropped, none added.
func Test_WrapWithOptionals(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	w := httptest.NewRecorder()
	testWriterWrapper := writerWrapper{w, 0}
	type testNoOptional struct {
		http.ResponseWriter
	}
	type testPusherHijacker struct {
		http.ResponseWriter
		testHijacker
		testPusher
	}
	type testPusherFlusher struct {
		http.ResponseWriter
		testPusher
		testFlusher
	}
	type testFlusherHijacker struct {
		http.ResponseWriter
		testFlusher
		testHijacker
	}
	type testAll struct {
		http.ResponseWriter
		testFlusher
		testHijacker
		testPusher
	}
	tests := []struct {
		name            string
		with            *writerWrapper
		wrap            http.ResponseWriter
		wantFlusher     bool
		wantPusher      bool
		wantHijacker    bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "missing-test-writer",
			wrap:            &testFlusher{},
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing writer wrapper",
		},
		{
			name:            "missing-wrapper",
			with:            &testWriterWrapper,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing response writer",
		},
		{
			name: "success-no-optional",
			with: &testWriterWrapper,
			wrap: &testNoOptional{},
		},
		{
			name:        "success-flusher",
			with:        &testWriterWrapper,
			wrap:        &testFlusher{},
			wantFlusher: true,
		},
		{
			name:         "success-flusher-hijacker",
			with:         &testWriterWrapper,
			wrap:         &testFlusherHijacker{},
			wantFlusher:  true,
			wantHijacker: true,
		},
		{
			name:       "success-pusher",
			with:       &testWriterWrapper,
			wrap:       &testPusher{},
			wantPusher: true,
		},
		{
			name:         "success-pusher-hijacker",
			with:         &testWriterWrapper,
			wrap:         &testPusherHijacker{},
			wantHijacker: true,
			wantPusher:   true,
		},
		{
			name:        "success-pusher-flusher",
			with:        &testWriterWrapper,
			wrap:        &testPusherFlusher{},
			wantFlusher: true,
			wantPusher:  true,
		},
		{
			name:         "success-hijacker",
			with:         &testWriterWrapper,
			wrap:         &testHijacker{},
			wantHijacker: true,
		},
		{
			name:         "success-all",
			with:         &testWriterWrapper,
			wrap:         &testAll{},
			wantFlusher:  true,
			wantPusher:   true,
			wantHijacker: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			wrapped, err := WrapWithOptionals(ctx, tt.with, tt.wrap)
			if tt.wantErrMatch != nil {
				require.Error(err)
				assert.Nil(wrapped)
				assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %q and got %q", tt.wantErrMatch.Code, err.Error())
				if tt.wantErrContains != "" {
					assert.Contains(err.Error(), tt.wantErrContains)
				}
				return
			}
			require.NoError(err)
			require.NotNil(wrapped)
			_, ok := wrapped.(interface{ StatusCode() int })
			assert.Truef(ok, "wanted a response writer that satisfies the StatusCode interface")

			// Compare presence in both directions. The wrapper must not drop an
			// optional interface the wrapped writer implements. It must also
			// not advertise one the wrapped writer lacks.
			_, isFlusher := wrapped.(http.Flusher)
			assert.Equalf(tt.wantFlusher, isFlusher, "http.Flusher: wanted %t, got %t", tt.wantFlusher, isFlusher)
			_, isPusher := wrapped.(http.Pusher)
			assert.Equalf(tt.wantPusher, isPusher, "http.Pusher: wanted %t, got %t", tt.wantPusher, isPusher)
			_, isHijacker := wrapped.(http.Hijacker)
			assert.Equalf(tt.wantHijacker, isHijacker, "http.Hijacker: wanted %t, got %t", tt.wantHijacker, isHijacker)
		})
	}
}
// Test_WrapWithEventsHandler verifies that WrapWithEventsHandler rejects a
// missing handler, eventer or kms. It checks the status code returned by the
// wrapped handler. For cases whose eventer is backed by testMockBroker, it
// checks that nothing is written to the test sinks. For the success case, it
// checks that the observation and audit events written to the sinks match the
// expected JSON.
func Test_WrapWithEventsHandler(t *testing.T) {
	// This cannot run in parallel because it relies on a pkg var common.privateNets
	wrapper := db.TestWrapper(t)
	conn, _ := db.TestSetup(t, "postgres")
	testKms := kms.TestKms(t, conn, wrapper)
	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTeapot)
		fmt.Fprintln(w, "I'm a little teapot short and stout")
	})
	goodAddr, err := sockaddr.NewIPAddr("127.0.0.1")
	require.NoError(t, err)
	testListenerCfg := cfgListener(goodAddr)
	testListenerCfg.XForwardedForRejectNotPresent = false

	c := event.TestEventerConfig(t, "Test_WrapWithEventsHandler", event.TestWithAuditSink(t), event.TestWithObservationSink(t))
	testLock := &sync.Mutex{}
	testLogger := hclog.New(&hclog.LoggerOptions{
		Mutex: testLock,
		Name:  "test",
	})
	testEventer, err := event.NewEventer(testLogger, testLock, "Test_WrapWithEventsHandler", c.EventerConfig)
	require.NoError(t, err)

	tests := []struct {
		name            string
		h               http.Handler
		e               *event.Eventer
		kms             *kms.Kms
		statusCode      int
		noEventJson     bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "missing handler",
			e:               testEventer,
			kms:             testKms,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing handler",
		},
		{
			name: "missing eventer",
			h: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprintln(w, "Hello, client")
			}),
			kms:             testKms,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing eventer",
		},
		{
			name:            "missing kms",
			h:               testHandler,
			e:               testEventer,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing kms",
		},
		{
			name: "audit-startGatedEvents",
			h:    testHandler,
			e: func() *event.Eventer {
				b := &testMockBroker{errorOnSendAudit: true}
				c := event.EventerConfig{AuditEnabled: true}
				e, err := event.NewEventer(testLogger, testLock, "audit-startGatedEvents", c, event.TestWithBroker(t, b))
				require.NoError(t, err)
				return e
			}(),
			kms:         testKms,
			statusCode:  http.StatusInternalServerError,
			noEventJson: true,
		},
		{
			name: "audit-flushGatedEvents",
			h:    testHandler,
			e: func() *event.Eventer {
				b := &testMockBroker{errorOnFlush: true}
				c := event.EventerConfig{AuditEnabled: true}
				e, err := event.NewEventer(testLogger, testLock, "audit-flushGatedEvents", c, event.TestWithBroker(t, b))
				require.NoError(t, err)
				return e
			}(),
			kms:         testKms,
			statusCode:  http.StatusTeapot, // this isn't ideal, but the write by the test handler will send a teapot status
			noEventJson: true,
		},
		{
			name:       "success",
			h:          testHandler,
			e:          testEventer,
			kms:        testKms,
			statusCode: http.StatusTeapot,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			got, err := WrapWithEventsHandler(context.Background(), tt.h, tt.e, tt.kms, testListenerCfg)
			if tt.wantErrMatch != nil {
				require.Error(err)
				assert.Nil(got)
				assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %q and got %q", tt.wantErrMatch.Code, err.Error())
				if tt.wantErrContains != "" {
					assert.Contains(err.Error(), tt.wantErrContains)
				}
				return
			}
			require.NoError(err)
			// ServeHTTP below would panic on a nil handler, so stop here instead.
			require.NotNil(got)

			req, err := http.NewRequest("GET", "/greeting", nil)
			require.NoError(err)
			rr := httptest.NewRecorder()
			got.ServeHTTP(rr, req)
			assert.Equal(tt.statusCode, rr.Code)

			// Each check below reads one sink file. It defers truncating that
			// file so the next subtest starts from an empty sink. Truncation is
			// best-effort cleanup, so its error is ignored.
			{ // test that the got observation is what we wanted.
				require.NotNil(c.ObservationEvents)
				defer func() { _ = os.WriteFile(c.ObservationEvents.Name(), nil, 0o666) }()
				b, err := os.ReadFile(c.ObservationEvents.Name())
				require.NoError(err)
				if tt.noEventJson {
					assert.Lenf(b, 0, "expected no json for internal errors but got %s", string(b))
					// These cases use an eventer backed by testMockBroker, which
					// never writes to the test sinks. The audit check below is
					// therefore skipped as well.
					return
				}
				got := &cloudevents.Event{}
				err = json.Unmarshal(b, got)
				require.NoErrorf(err, "json: %s", string(b))
				actualJson, err := json.Marshal(got)
				require.NoError(err)
				// set the got values to the wanted values that are either
				// static or calculated in real-time
				info := event.RequestInfo{
					Method: "GET",
					Path:   "/greeting",
					Id:     got.Data.(map[string]any)["request_info"].(map[string]any)["id"].(string),
				}
				hdr := map[string]any{
					"status":     tt.statusCode,
					"start":      got.Data.(map[string]any)["start"].(string),
					"stop":       got.Data.(map[string]any)["stop"].(string),
					"latency-ms": got.Data.(map[string]any)["latency-ms"].(float64),
				}
				wantJson := testJson(t, event.ObservationType, &info, event.Op(tt.name), got, hdr, nil)
				assert.JSONEq(string(wantJson), string(actualJson))
			}
			{ // test that the got audit is what we wanted.
				require.NotNil(c.AuditEvents)
				defer func() { _ = os.WriteFile(c.AuditEvents.Name(), nil, 0o666) }()
				b, err := os.ReadFile(c.AuditEvents.Name())
				require.NoError(err)
				got := &cloudevents.Event{}
				err = json.Unmarshal(b, got)
				require.NoErrorf(err, "json: %s", string(b))
				actualJson, err := json.Marshal(got)
				require.NoError(err)
				// set the got values to the wanted values that are either
				// static or calculated in real-time
				info := event.RequestInfo{
					Method: "GET",
					Path:   "/greeting",
					Id:     got.Data.(map[string]any)["request_info"].(map[string]any)["id"].(string),
				}
				hdr := map[string]any{
					"id":        got.Data.(map[string]any)["id"].(string),
					"timestamp": got.Data.(map[string]any)["timestamp"].(string),
					"response":  got.Data.(map[string]any)["response"].(map[string]any),
				}
				wantJson := testJson(t, event.AuditType, &info, event.Op(tt.name), got, hdr, nil)
				assert.JSONEq(string(wantJson), string(actualJson))
			}
		})
	}
}
// Test_startGatedEvents checks that startGatedEvents returns an
// errors.Internal error when the testMockBroker behind the context's eventer
// fails to send either the audit event or the observation event.
func Test_startGatedEvents(t *testing.T) {
	now := time.Now()
	cases := []struct {
		name             string
		errOnAudit       bool
		errOnObservation bool
		startTime        time.Time
		wantErrMatch     *errors.Template
		wantErrContains  string
	}{
		{
			name:         "audit-failed",
			errOnAudit:   true,
			startTime:    now,
			wantErrMatch: errors.T(errors.Internal),
		},
		{
			name:             "observation-failed",
			errOnObservation: true,
			wantErrMatch:     errors.T(errors.Internal),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require, assert := require.New(t), assert.New(t)

			// Each subtest builds its own lock, logger and broker-backed eventer.
			lock := &sync.Mutex{}
			logger := hclog.New(&hclog.LoggerOptions{Mutex: lock, Name: "test"})
			broker := &testMockBroker{
				errorOnSendAudit:       tc.errOnAudit,
				errorOnSendObservation: tc.errOnObservation,
			}
			cfg := event.EventerConfig{AuditEnabled: true, ObservationsEnabled: true}
			eventer, err := event.NewEventer(logger, lock, tc.name, cfg, event.TestWithBroker(t, broker))
			require.NoError(err)
			ctx, err := event.NewEventerContext(context.Background(), eventer)
			require.NoError(err)

			err = startGatedEvents(ctx, "GET", "/hello", tc.startTime)
			if tc.wantErrMatch == nil {
				require.NoError(err)
				return
			}
			require.Error(err)
			assert.Truef(errors.Match(tc.wantErrMatch, err), "wanted %q and got %q", tc.wantErrMatch.Code, err.Error())
			if tc.wantErrContains != "" {
				assert.Contains(err.Error(), tc.wantErrContains)
			}
		})
	}
}
// Test_flushGatedEvents checks that flushGatedEvents returns an
// errors.Internal error when the testMockBroker behind the context's eventer
// fails to send either the audit event or the observation event.
func Test_flushGatedEvents(t *testing.T) {
	now := time.Now()
	cases := []struct {
		name             string
		errOnAudit       bool
		errOnObservation bool
		startTime        time.Time
		wantErrMatch     *errors.Template
		wantErrContains  string
	}{
		{
			name:         "audit-failed",
			errOnAudit:   true,
			startTime:    now,
			wantErrMatch: errors.T(errors.Internal),
		},
		{
			name:             "observation-failed",
			errOnObservation: true,
			wantErrMatch:     errors.T(errors.Internal),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require, assert := require.New(t), assert.New(t)

			// Each subtest builds its own lock, logger and broker-backed eventer.
			lock := &sync.Mutex{}
			logger := hclog.New(&hclog.LoggerOptions{Mutex: lock, Name: "test"})
			broker := &testMockBroker{
				errorOnSendAudit:       tc.errOnAudit,
				errorOnSendObservation: tc.errOnObservation,
			}
			cfg := event.EventerConfig{AuditEnabled: true, ObservationsEnabled: true}
			eventer, err := event.NewEventer(logger, lock, tc.name, cfg, event.TestWithBroker(t, broker))
			require.NoError(err)
			ctx, err := event.NewEventerContext(context.Background(), eventer)
			require.NoError(err)

			err = flushGatedEvents(ctx, "GET", "/hello", http.StatusOK, tc.startTime)
			if tc.wantErrMatch == nil {
				require.NoError(err)
				return
			}
			require.Error(err)
			assert.Truef(errors.Match(tc.wantErrMatch, err), "wanted %q and got %q", tc.wantErrMatch.Code, err.Error())
			if tc.wantErrContains != "" {
				assert.Contains(err.Error(), tc.wantErrContains)
			}
		})
	}
}
// testMockBroker is a fake broker handed to eventers via event.TestWithBroker.
// Its Send method fails with an errors.Internal error for the kinds of sends
// selected by the flags below and succeeds for everything else.
type testMockBroker struct {
	// errorOnSendAudit, errorOnSendObservation and errorOnFlush select which
	// Send calls fail: audit events, observation events, and gateable flush
	// events, respectively.
	errorOnSendAudit, errorOnSendObservation, errorOnFlush bool
}
// Send fakes delivery of payload. Failure checks are made in order:
//  1. errorOnFlush: payload is a gated.Gateable whose FlushEvent() is true.
//  2. errorOnSendAudit: t is the audit event type.
//  3. errorOnSendObservation: t is the observation event type.
//
// The first match returns an errors.Internal error. If none match, Send
// returns an empty eventlogger.Status and nil.
func (b *testMockBroker) Send(ctx context.Context, t eventlogger.EventType, payload any) (eventlogger.Status, error) {
	const op = "common.(testMockBroker).Send"
	gateable, isGateable := payload.(gated.Gateable)
	var reason string
	switch {
	case b.errorOnFlush && isGateable && gateable.FlushEvent():
		reason = "unable to flush event"
	case b.errorOnSendAudit && t == eventlogger.EventType(event.AuditType):
		reason = "unable to send audit event"
	case b.errorOnSendObservation && t == eventlogger.EventType(event.ObservationType):
		reason = "unable to send observation event"
	default:
		return eventlogger.Status{}, nil
	}
	return eventlogger.Status{}, errors.New(ctx, errors.Internal, op, reason)
}
// The remaining testMockBroker methods are no-ops. They presumably exist to
// satisfy the broker interface accepted by event.TestWithBroker. Every method
// that returns an error returns nil.

// Reopen does nothing.
func (b *testMockBroker) Reopen(context.Context) error {
	return nil
}

// RegisterPipeline discards the pipeline definition and options.
func (b *testMockBroker) RegisterPipeline(eventlogger.Pipeline, ...eventlogger.Option) error {
	return nil
}

// StopTimeAt ignores the given time.
func (b *testMockBroker) StopTimeAt(time.Time) {}

// RegisterNode discards the node ID, node and options.
func (b *testMockBroker) RegisterNode(eventlogger.NodeID, eventlogger.Node, ...eventlogger.Option) error {
	return nil
}

// RemoveNode does nothing.
func (b *testMockBroker) RemoveNode(context.Context, eventlogger.NodeID) error {
	return nil
}

// RemovePipelineAndNodes removes nothing but always reports true.
func (b *testMockBroker) RemovePipelineAndNodes(context.Context, eventlogger.EventType, eventlogger.PipelineID) (bool, error) {
	return true, nil
}

// SetSuccessThreshold ignores the event type and threshold.
func (b *testMockBroker) SetSuccessThreshold(eventlogger.EventType, int) error {
	return nil
}
// eventJson describes a JSON event object with created_at, event_type and
// payload fields.
// NOTE(review): nothing in this part of the file references eventJson.
// Confirm whether other tests in the package still use it before removing it.
type eventJson struct {
	CreatedAt string         `json:"created_at"`
	EventType string         `json:"event_type"`
	Payload   map[string]any `json:"payload"`
}
// testJson returns the JSON encoding of the cloudevents.Event a test expects
// for an event of eventType.
//
// Values generated at runtime are copied from got: ID, Time, Source,
// SpecVersion, Type, DataContentType and, for audit events, the payload id.
// The payload is built from reqInfo, a fixed version and, for audit events,
// the event.ApiRequest type. Every hdr entry is then copied on top of it.
// When details is non-nil, the payload gets a single details entry. Its
// created_at and type are copied from got's first details entry. Its payload
// is a copy of details plus an op field set to caller. The caller's details
// map is not modified.
//
// Only event.ObservationType and event.AuditType are supported. Any other
// type, or a got event missing a field this function reads, fails the test.
func testJson(t *testing.T, eventType event.Type, reqInfo *event.RequestInfo, caller event.Op, got *cloudevents.Event, hdr, details map[string]any) []byte {
	t.Helper()
	const (
		testAuditVersion       = "v0.1"
		testObservationVersion = "v0.1"
	)
	require := require.New(t)
	require.NotNil(got, "missing got event")

	// gotData returns got's data as a map, failing the test if it is not one.
	gotData := func() map[string]any {
		data, ok := got.Data.(map[string]any)
		require.Truef(ok, "got event data is %T, want map[string]any", got.Data)
		return data
	}

	var payload map[string]any
	switch eventType {
	case event.ObservationType:
		payload = map[string]any{
			event.RequestInfoField: reqInfo,
			event.VersionField:     testObservationVersion,
		}
	case event.AuditType:
		id, ok := gotData()[event.IdField].(string)
		require.Truef(ok, "got audit event has no string %q field", event.IdField)
		payload = map[string]any{
			event.IdField:          id,
			event.RequestInfoField: reqInfo,
			event.VersionField:     testAuditVersion,
			event.TypeField:        event.ApiRequest,
		}
	default:
		// Without a payload map, the details branch below would write to a
		// nil map, so fail explicitly.
		require.FailNowf("unsupported event type", "testJson does not support event type %q", eventType)
	}
	for k, v := range hdr {
		payload[k] = v
	}

	if details != nil {
		// Copy details so adding the op field does not mutate the caller's map.
		detailsPayload := make(map[string]any, len(details)+1)
		for k, v := range details {
			detailsPayload[k] = v
		}
		detailsPayload[event.OpField] = string(caller)

		gotDetails, ok := gotData()[event.DetailsField].([]any)
		require.Truef(ok && len(gotDetails) > 0, "got event has no %q entries", event.DetailsField)
		d, ok := gotDetails[0].(map[string]any)
		require.Truef(ok, "got event's first %q entry is %T, want map[string]any", event.DetailsField, gotDetails[0])
		createdAt, ok := d[event.CreatedAtField].(string)
		require.Truef(ok, "got event's first %q entry has no string %q field", event.DetailsField, event.CreatedAtField)
		detailType, ok := d[event.TypeField].(string)
		require.Truef(ok, "got event's first %q entry has no string %q field", event.DetailsField, event.TypeField)

		payload[event.DetailsField] = []struct {
			CreatedAt string         `json:"created_at"`
			Type      string         `json:"type"`
			Payload   map[string]any `json:"payload"`
		}{
			{
				CreatedAt: createdAt,
				Type:      detailType,
				Payload:   detailsPayload,
			},
		}
	}

	want := cloudevents.Event{
		ID:              got.ID,
		Time:            got.Time,
		Source:          got.Source,
		SpecVersion:     got.SpecVersion,
		Type:            got.Type,
		DataContentType: got.DataContentType,
		Data:            payload,
	}
	b, err := json.Marshal(want)
	require.NoError(err)
	return b
}
// testFlusher is an http.ResponseWriter whose Flush method does nothing. It
// lets tests build writers that implement http.Flusher.
type testFlusher struct {
	http.ResponseWriter
}

// Compile-time check that *testFlusher satisfies http.Flusher.
var _ http.Flusher = (*testFlusher)(nil)

// Flush is a no-op.
func (*testFlusher) Flush() {}
// testPusher is an http.ResponseWriter whose Push method always succeeds
// without doing anything. It lets tests build writers that implement
// http.Pusher.
type testPusher struct {
	http.ResponseWriter
}

// Compile-time check that *testPusher satisfies http.Pusher.
var _ http.Pusher = (*testPusher)(nil)

// Push ignores its arguments and returns nil.
func (*testPusher) Push(string, *http.PushOptions) error { return nil }
// testHijacker is an http.ResponseWriter whose Hijack method returns nil for
// every result. It lets tests build writers that implement http.Hijacker.
type testHijacker struct {
	http.ResponseWriter
}

// Compile-time check that *testHijacker satisfies http.Hijacker.
var _ http.Hijacker = (*testHijacker)(nil)

// Hijack returns a nil connection, a nil read/writer and a nil error.
func (*testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	var (
		conn net.Conn
		rw   *bufio.ReadWriter
	)
	return conn, rw, nil
}