You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/controller/interceptor_test.go

1182 lines
38 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package controller
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/alias"
talias "github.com/hashicorp/boundary/internal/alias/target"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/daemon/controller/auth"
"github.com/hashicorp/boundary/internal/daemon/controller/common"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
pb_api "github.com/hashicorp/boundary/internal/gen/controller/api"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
authpb "github.com/hashicorp/boundary/internal/gen/controller/auth"
pberrors "github.com/hashicorp/boundary/internal/gen/errors"
"github.com/hashicorp/boundary/internal/gen/testing/interceptor"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/target/tcp"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/mr-tron/base58"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
// Test_unaryCtxInterceptor exercises requestCtxUnaryInterceptor.
//
// Factory cases leave out exactly one required dependency and expect an
// InvalidParameter error naming it. Request cases feed missing, malformed, or
// mismatched request-info metadata and expect an Internal error. The valid case
// checks that the context handed to the handler can be consumed by auth.Verify.
func Test_unaryCtxInterceptor(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	iamRepoFn := func() (*iam.Repository, error) {
		return iamRepo, nil
	}
	atRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(ctx, rw, rw, kmsCache)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(ctx, rw, rw, kmsCache)
	}
	validGatewayTicket := "valid-ticket"
	o, _ := iam.TestScopes(t, iamRepo)
	at := authtoken.TestAuthToken(t, conn, kmsCache, o.GetPublicId())
	encToken, err := authtoken.EncryptToken(ctx, kmsCache, o.GetPublicId(), at.GetPublicId(), at.GetToken())
	require.NoError(t, err)
	tokValue := at.GetPublicId() + "_" + encToken

	// newReqCtx builds an incoming gRPC context whose metadata carries a
	// base58-encoded authpb.RequestInfo for a request bearing tokValue and the
	// given gateway ticket.
	newReqCtx := func(gwTicket string) context.Context {
		req := httptest.NewRequest("GET", "http://127.0.0.1/v1/scopes/o_1", nil)
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokValue))
		// Add values for authn/authz checking
		requestInfo := authpb.RequestInfo{
			Path:    req.URL.Path,
			Method:  req.Method,
			EventId: "test-event-id",
			TraceId: "test-trace-id",
		}
		requestInfo.PublicId, requestInfo.EncryptedToken, requestInfo.TokenFormat = auth.GetTokenFromRequest(context.TODO(), kmsCache, req)
		requestInfo.Ticket = gwTicket // allows the grpc-gateway to verify the request info came from its in-memory companion http proxy
		marshalledRequestInfo, err := proto.Marshal(&requestInfo)
		require.NoError(t, err)
		md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding(marshalledRequestInfo))
		mdCtx := metadata.NewIncomingContext(context.Background(), md)
		md, ok := metadata.FromIncomingContext(mdCtx)
		require.True(t, ok)
		require.NotNil(t, md)
		return mdCtx
	}
	factoryCtx := context.Background()
	// returnCtxHandler hands back the context it was invoked with so the test
	// can inspect what the interceptor attached to it.
	returnCtxHandler := func(ctx context.Context, req any) (any, error) {
		return ctx, nil
	}
	c := event.TestEventerConfig(t, "Test_unaryCtxInterceptor", event.TestWithAuditSink(t), event.TestWithObservationSink(t))
	testLock := &sync.Mutex{}
	testLogger := hclog.New(&hclog.LoggerOptions{
		Mutex: testLock,
		Name:  "test",
	})
	testEventer, err := event.NewEventer(testLogger, testLock, "Test_unaryCtxInterceptor", c.EventerConfig)
	require.NoError(t, err)
	tests := []struct {
		name                   string
		requestCtx             context.Context
		iamRepoFn              common.IamRepoFactory
		authTokenRepoFn        common.AuthTokenRepoFactory
		serversRepoFn          common.ServersRepoFactory
		kms                    *kms.Kms
		eventer                *event.Eventer
		ticket                 string
		wantFactoryErr         bool
		wantFactoryErrMatch    *errors.Template
		wantFactoryErrContains string
		wantRequestErr         bool
		wantRequestErrMatch    *errors.Template
		wantRequestErrContains string
	}{
		{
			name:                   "missing-iam-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing iam repo",
		},
		{
			name:                   "missing-at-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing auth token repo",
		},
		{
			name:                   "missing-servers-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing server repo function",
		},
		{
			name:                   "missing-kms",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing kms",
		},
		{
			// Every other dependency (including kms and ticket) is supplied so
			// that only the eventer check can fail.
			name:                   "missing-eventer",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			ticket:                 validGatewayTicket,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing eventer",
		},
		{
			name:                   "missing-factory-ticket",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing ticket",
		},
		{
			name:                   "missing-metadata",
			requestCtx:             context.Background(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "No metadata",
		},
		{
			name: "missing-request-info-metadata",
			requestCtx: func() context.Context {
				md := metadata.Pairs("greeter-md", "hello")
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "Missing request metadata",
		},
		{
			name: "too-many-request-info-metadata",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, "first", requestInfoMdKey, "second")
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "expected 1 value",
		},
		{
			name: "request-info-metadata-not-encoded",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, "hello")
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "unable to decode request info",
		},
		{
			name: "request-info-metadata-not-proto",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding([]byte("hello")))
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "unable to unmarshal request info",
		},
		{
			name:                   "missing-request-ticket",
			requestCtx:             newReqCtx(""),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "Invalid context (missing ticket)",
		},
		{
			name:                   "bad-ticket",
			requestCtx:             newReqCtx("bad-ticket"),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "Invalid context (bad ticket)",
		},
		{
			name:            "valid",
			requestCtx:      newReqCtx(validGatewayTicket),
			iamRepoFn:       iamRepoFn,
			authTokenRepoFn: atRepoFn,
			serversRepoFn:   serversRepoFn,
			kms:             kmsCache,
			eventer:         testEventer,
			ticket:          validGatewayTicket,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			interceptor, err := requestCtxUnaryInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer)
			if tt.wantFactoryErr {
				require.Error(err)
				assert.Nil(interceptor)
				if tt.wantFactoryErrMatch != nil {
					assert.Truef(errors.Match(tt.wantFactoryErrMatch, err), "want err code: %q got: %q", tt.wantFactoryErrMatch.Code, err)
				}
				if tt.wantFactoryErrContains != "" {
					assert.Contains(err.Error(), tt.wantFactoryErrContains)
				}
				return
			}
			require.NoError(err)
			assert.NotNil(interceptor)
			info := &grpc.UnaryServerInfo{
				FullMethod: "FakeMethod",
			}
			retCtx, err := interceptor(tt.requestCtx, nil, info, returnCtxHandler)
			if tt.wantRequestErr {
				require.Error(err)
				assert.Nil(retCtx)
				if tt.wantRequestErrMatch != nil {
					assert.Truef(errors.Match(tt.wantRequestErrMatch, err), "want err code: %q got: %q", tt.wantRequestErrMatch.Code, err)
				}
				if tt.wantRequestErrContains != "" {
					assert.Contains(err.Error(), tt.wantRequestErrContains)
				}
				return
			}
			require.NoError(err)
			// Use resource.Scope here but resource type shouldn't matter since we're only validating that auth.Verify
			// can use the 'context' with additional ctx.Value from the interceptor
			verifyResults := auth.Verify(retCtx.(context.Context), resource.Scope)
			assert.NotEmpty(verifyResults)
		})
	}
}
// Test_streamCtxInterceptor exercises requestCtxStreamInterceptor.
//
// Factory cases leave out exactly one required dependency and expect an
// InvalidParameter error naming it. Request cases feed missing, malformed, or
// mismatched request-info metadata through a streamMock and expect an Internal
// error. The valid case checks that the stream context seen by the handler can
// be consumed by auth.Verify.
func Test_streamCtxInterceptor(t *testing.T) {
	t.Parallel()
	factoryCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	iamRepoFn := func() (*iam.Repository, error) {
		return iamRepo, nil
	}
	atRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(factoryCtx, rw, rw, kmsCache)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(factoryCtx, rw, rw, kmsCache)
	}
	validGatewayTicket := "valid-ticket"
	o, _ := iam.TestScopes(t, iamRepo)
	at := authtoken.TestAuthToken(t, conn, kmsCache, o.GetPublicId())
	encToken, err := authtoken.EncryptToken(factoryCtx, kmsCache, o.GetPublicId(), at.GetPublicId(), at.GetToken())
	require.NoError(t, err)
	tokValue := at.GetPublicId() + "_" + encToken

	// newReqCtx builds an incoming gRPC context whose metadata carries a
	// base58-encoded authpb.RequestInfo for a request bearing tokValue and the
	// given gateway ticket.
	newReqCtx := func(gwTicket string) context.Context {
		req := httptest.NewRequest("GET", "http://127.0.0.1/v1/scopes/o_1", nil)
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokValue))
		// Add values for authn/authz checking
		requestInfo := authpb.RequestInfo{
			Path:    req.URL.Path,
			Method:  req.Method,
			EventId: "test-event-id",
			TraceId: "test-trace-id",
		}
		requestInfo.PublicId, requestInfo.EncryptedToken, requestInfo.TokenFormat = auth.GetTokenFromRequest(context.TODO(), kmsCache, req)
		requestInfo.Ticket = gwTicket // allows the grpc-gateway to verify the request info came from its in-memory companion http proxy
		marshalledRequestInfo, err := proto.Marshal(&requestInfo)
		require.NoError(t, err)
		md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding(marshalledRequestInfo))
		mdCtx := metadata.NewIncomingContext(factoryCtx, md)
		md, ok := metadata.FromIncomingContext(mdCtx)
		require.True(t, ok)
		require.NotNil(t, md)
		return mdCtx
	}
	c := event.TestEventerConfig(t, "Test_streamCtxInterceptor", event.TestWithAuditSink(t), event.TestWithObservationSink(t))
	testLock := &sync.Mutex{}
	testLogger := hclog.New(&hclog.LoggerOptions{
		Mutex: testLock,
		Name:  "test",
	})
	testEventer, err := event.NewEventer(testLogger, testLock, "Test_streamCtxInterceptor", c.EventerConfig)
	require.NoError(t, err)
	tests := []struct {
		name                   string
		requestCtx             context.Context
		iamRepoFn              common.IamRepoFactory
		authTokenRepoFn        common.AuthTokenRepoFactory
		serversRepoFn          common.ServersRepoFactory
		kms                    *kms.Kms
		eventer                *event.Eventer
		ticket                 string
		wantFactoryErr         bool
		wantFactoryErrMatch    *errors.Template
		wantFactoryErrContains string
		wantRequestErr         bool
		wantRequestErrMatch    *errors.Template
		wantRequestErrContains string
	}{
		{
			name:                   "missing-iam-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing iam repo",
		},
		{
			name:                   "missing-at-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing auth token repo",
		},
		{
			name:                   "missing-servers-repo",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing server repo function",
		},
		{
			name:                   "missing-kms",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing kms",
		},
		{
			// Every other dependency (including kms and ticket) is supplied so
			// that only the eventer check can fail.
			name:                   "missing-eventer",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			ticket:                 validGatewayTicket,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing eventer",
		},
		{
			name:                   "missing-factory-ticket",
			requestCtx:             newReqCtx(validGatewayTicket),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing ticket",
		},
		{
			name:                   "missing-metadata",
			requestCtx:             context.Background(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "No metadata",
		},
		{
			name: "too-many-request-info-metadata",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, "first", requestInfoMdKey, "second")
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "expected 1 value",
		},
		{
			name: "request-info-metadata-not-encoded",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, "hello")
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "unable to decode request info",
		},
		{
			name: "request-info-metadata-not-proto",
			requestCtx: func() context.Context {
				md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding([]byte("hello")))
				return metadata.NewIncomingContext(context.Background(), md)
			}(),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "unable to unmarshal request info",
		},
		{
			name:                   "missing-request-ticket",
			requestCtx:             newReqCtx(""),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "Invalid context (missing ticket)",
		},
		{
			name:                   "bad-ticket",
			requestCtx:             newReqCtx("bad-ticket"),
			iamRepoFn:              iamRepoFn,
			authTokenRepoFn:        atRepoFn,
			serversRepoFn:          serversRepoFn,
			kms:                    kmsCache,
			eventer:                testEventer,
			ticket:                 validGatewayTicket,
			wantRequestErr:         true,
			wantRequestErrMatch:    errors.T(errors.Internal),
			wantRequestErrContains: "Invalid context (bad ticket)",
		},
		{
			name:            "valid",
			requestCtx:      newReqCtx(validGatewayTicket),
			iamRepoFn:       iamRepoFn,
			authTokenRepoFn: atRepoFn,
			serversRepoFn:   serversRepoFn,
			kms:             kmsCache,
			eventer:         testEventer,
			ticket:          validGatewayTicket,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			interceptor, err := requestCtxStreamInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer)
			if tt.wantFactoryErr {
				require.Error(err)
				assert.Nil(interceptor)
				if tt.wantFactoryErrMatch != nil {
					assert.Truef(errors.Match(tt.wantFactoryErrMatch, err), "want err code: %q got: %q", tt.wantFactoryErrMatch.Code, err)
				}
				if tt.wantFactoryErrContains != "" {
					assert.Contains(err.Error(), tt.wantFactoryErrContains)
				}
				return
			}
			require.NoError(err)
			assert.NotNil(interceptor)
			info := &grpc.StreamServerInfo{
				FullMethod:     "FakeMethod",
				IsClientStream: true,
			}
			// hd captures the (possibly wrapped) stream context the interceptor
			// passes on, so it can be inspected after the call returns.
			var hdCtx context.Context
			hd := func(srv any, stream grpc.ServerStream) error {
				hdCtx = stream.Context()
				return nil
			}
			m := &streamMock{ctx: tt.requestCtx}
			err = interceptor(nil, m, info, hd)
			if tt.wantRequestErr {
				require.Error(err)
				if tt.wantRequestErrMatch != nil {
					assert.Truef(errors.Match(tt.wantRequestErrMatch, err), "want err code: %q got: %q", tt.wantRequestErrMatch.Code, err)
				}
				if tt.wantRequestErrContains != "" {
					assert.Contains(err.Error(), tt.wantRequestErrContains)
				}
				return
			}
			require.NoError(err)
			require.NotNil(hdCtx, "stream handler was not invoked")
			// Use resource.Scope here but resource type shouldn't matter since we're only validating that auth.Verify
			// can use the 'context' with additional ctx.Value from the interceptor
			verifyResults := auth.Verify(hdCtx, resource.Scope)
			assert.NotEmpty(verifyResults)
		})
	}
}
// streamMock is a minimal grpc.ServerStream stand-in for the stream
// interceptor tests. It only carries a context. The embedded ServerStream is
// left nil, so calling any ServerStream method it does not override will
// panic.
type streamMock struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the context the mock was built with.
func (s *streamMock) Context() context.Context { return s.ctx }

// Send is not supported by the mock and always panics.
func (s *streamMock) Send(_ *httpbody.HttpBody) error {
	panic("send not implemented")
}

// RecvToClient is not supported by the mock and always panics.
func (s *streamMock) RecvToClient() (*httpbody.HttpBody, error) {
	panic("recv not implemented")
}
// Test_aliasResolutionInterceptor verifies aliasResolutionInterceptor against a
// target and two aliases.
//
// For an aliasable request (GetTargetRequest) whose Id is an alias value, the
// interceptor rewrites the Id in place to the alias destination and exposes the
// alias via alias.FromContext. Non-aliasable requests and plain ids are left
// untouched. Unknown aliases and aliases without a destination fail with
// handlers.NotFoundError and leave the request unmodified.
func Test_aliasResolutionInterceptor(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	aliasRepoFn := func() (*alias.Repository, error) {
		return alias.NewRepository(ctx, rw, rw, kmsCache)
	}
	_, proj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), "test-target")
	al := talias.TestAlias(t, rw, "test-alias.example", talias.WithDestinationId(tar.GetPublicId()))
	alWithoutDest := talias.TestAlias(t, rw, "no-destination.alias")
	interceptor := aliasResolutionInterceptor(ctx, aliasRepoFn)
	require.NotNil(t, interceptor)
	// returnCtxHandler hands back the context it was invoked with so the test
	// can inspect what the interceptor attached to it.
	returnCtxHandler := func(ctx context.Context, req any) (any, error) {
		return ctx, nil
	}
	info := &grpc.UnaryServerInfo{
		FullMethod: "FakeMethod",
	}
	cases := []struct {
		name            string
		req             proto.Message
		wantModifiedReq proto.Message
		wantAlias       *alias.Alias
		errorIs         error
	}{
		{
			name:            "non aliasable request",
			req:             &pbs.GetAccountRequest{Id: "test"},
			wantModifiedReq: &pbs.GetAccountRequest{Id: "test"},
		},
		{
			name:            "non aliasable request with valid alias",
			req:             &pbs.GetAccountRequest{Id: al.GetValue()},
			wantModifiedReq: &pbs.GetAccountRequest{Id: al.GetValue()},
		},
		{
			name:            "aliasable request with id",
			req:             &pbs.GetTargetRequest{Id: tar.GetPublicId()},
			wantModifiedReq: &pbs.GetTargetRequest{Id: tar.GetPublicId()},
		},
		{
			name:            "aliasable request with alias",
			req:             &pbs.GetTargetRequest{Id: al.GetValue()},
			wantModifiedReq: &pbs.GetTargetRequest{Id: tar.GetPublicId()},
			wantAlias: &alias.Alias{
				PublicId:      al.GetPublicId(),
				Value:         al.GetValue(),
				DestinationId: tar.GetPublicId(),
			},
		},
		{
			name:            "aliasable request with unknown alias",
			req:             &pbs.GetTargetRequest{Id: "not.a.registered.alias"},
			wantModifiedReq: &pbs.GetTargetRequest{Id: "not.a.registered.alias"},
			errorIs:         handlers.NotFoundError(),
		},
		{
			name:            "aliasable request with destinationless alias",
			req:             &pbs.GetTargetRequest{Id: alWithoutDest.GetValue()},
			wantModifiedReq: &pbs.GetTargetRequest{Id: alWithoutDest.GetValue()},
			errorIs:         handlers.NotFoundError(),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Clone so the interceptor's in-place rewrite of the request does
			// not leak into the shared test table.
			req := proto.Clone(tc.req)
			retCtx, err := interceptor(context.Background(), req, info, returnCtxHandler)
			assert.Empty(t, cmp.Diff(tc.wantModifiedReq, req, protocmp.Transform()))
			if tc.errorIs != nil {
				require.Error(t, err)
				require.ErrorIs(t, err, tc.errorIs)
				return
			}
			// Fail cleanly on an unexpected error instead of panicking on the
			// type assertion of a nil retCtx below.
			require.NoError(t, err)
			ctxAlias := alias.FromContext(retCtx.(context.Context))
			if tc.wantAlias == nil {
				require.Nil(t, ctxAlias)
				return
			}
			require.NotNil(t, ctxAlias)
			assert.EqualValues(t, tc.wantAlias, ctxAlias)
		})
	}
}
// Test_errorInterceptor drives errorInterceptor through a live greeter service.
//
// A domain error must surface as exactly one base58-encoded pberrors.Err in the
// domainErrHeader response header. An API error must surface as exactly one
// base58-encoded pberrors.ApiError in apiErrHeader. A successful call must
// pass the handler's response through unchanged.
func Test_errorInterceptor(t *testing.T) {
	ctx := context.Background()
	tests := []struct {
		name           string
		factoryCtx     context.Context
		wantRequestErr bool
		wantDomainErr  bool
		wantApiErr     bool
	}{
		{
			name:           "domain-error",
			factoryCtx:     ctx,
			wantRequestErr: true,
			wantDomainErr:  true,
		},
		{
			name:           "api-error",
			factoryCtx:     ctx,
			wantRequestErr: true,
			wantApiErr:     true,
		},
		{
			name:       "success",
			factoryCtx: ctx,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			require.Falsef(tt.wantApiErr && tt.wantDomainErr, "tests can't require both wantApiErr and wantDomainErr")
			errInterceptor := errorInterceptor(tt.factoryCtx)
			client := startTestGreeterService(t, &testGreeter{}, errInterceptor)
			var header metadata.MD
			resp, err := client.SayHello(
				context.Background(),
				&interceptor.SayHelloRequest{Name: tt.name},
				grpc.Header(&header),
			)
			if !tt.wantRequestErr {
				require.NoError(err)
				// testGreeter answers any unrecognized name with "hello".
				assert.Equal("hello", resp.GetMessage())
				return
			}
			require.Error(err)
			assert.Nil(resp)
			domainHdr := header.Get(domainErrHeader)
			apiErrHdr := header.Get(apiErrHeader)
			if tt.wantDomainErr {
				require.Len(domainHdr, 1)
				assert.Len(apiErrHdr, 0)
				decoded, err := base58.FastBase58Decoding(domainHdr[0])
				require.NoError(err)
				var pbErr pberrors.Err
				err = proto.Unmarshal(decoded, &pbErr)
				require.NoError(err)
			}
			if tt.wantApiErr {
				require.Len(apiErrHdr, 1)
				assert.Len(domainHdr, 0)
				decoded, err := base58.FastBase58Decoding(apiErrHdr[0])
				require.NoError(err)
				var pbErr pberrors.ApiError
				err = proto.Unmarshal(decoded, &pbErr)
				require.NoError(err)
			}
		})
	}
}
// Test_statusCodeInterceptor runs statusCodeInterceptor in front of testGreeter.
//
// It checks that the handlers.StatusCodeHeader response header is set to
// http.StatusNoContent only for the (nil, nil) handler result. Both the
// erroring call and the ordinary "hello" call must leave the header absent.
func Test_statusCodeInterceptor(t *testing.T) {
	ctx := context.Background()
	cases := []struct {
		name           string
		wantStatusCode int
		wantErr        bool
	}{
		{name: "nil-nil", wantStatusCode: http.StatusNoContent},
		{name: "nil-err", wantErr: true},
		{name: "hello"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			client := startTestGreeterService(t, &testGreeter{}, statusCodeInterceptor(ctx))
			var respHeader metadata.MD
			_, callErr := client.SayHello(
				context.Background(),
				&interceptor.SayHelloRequest{Name: tc.name},
				grpc.Header(&respHeader),
			)
			if tc.wantErr {
				assert.Error(callErr)
			}
			statusValues := respHeader.Get(handlers.StatusCodeHeader)
			// A non-positive wantStatusCode means the header must not be present.
			if tc.wantStatusCode <= 0 {
				require.Len(statusValues, 0)
				return
			}
			require.Len(statusValues, 1)
			gotCode, err := strconv.Atoi(statusValues[0])
			require.NoError(err)
			assert.Equal(tc.wantStatusCode, gotCode)
		})
	}
}
// Test_parseUserAgents checks that parseUserAgents turns a raw User-Agent
// header into product/version/comment entries and skips malformed entries.
//
// Boundary-client-agent entries whose version is not a plain dotted version
// (for example "v0.1.4" or "0.1.x") yield no result. Other products may use
// any version string.
func Test_parseUserAgents(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name  string
		input string
		want  []*event.UserAgent // nil means no agents are expected
	}{
		{
			name:  "valid single user-agent",
			input: "Boundary-client-agent/0.1.4",
			want: []*event.UserAgent{
				{Product: "Boundary-client-agent", ProductVersion: "0.1.4"},
			},
		},
		{
			name:  "multiple valid agents with comments",
			input: "Boundary-client-agent/0.1.4 (foo; bar); AnotherApp/2.0.0 (baz )",
			want: []*event.UserAgent{
				{Product: "Boundary-client-agent", ProductVersion: "0.1.4", Comments: []string{"foo", "bar"}},
				{Product: "AnotherApp", ProductVersion: "2.0.0", Comments: []string{"baz"}},
			},
		},
		{
			name:  "complex but valid user agents",
			input: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 surveyon/2.9.5 (iPhone; CPU iPhone OS 12_5_7 like Mac OS X)",
			want: []*event.UserAgent{
				{Product: "Mozilla", ProductVersion: "5.0", Comments: []string{"Macintosh", "Intel Mac OS X 10_15_7"}},
				{Product: "AppleWebKit", ProductVersion: "537.36", Comments: []string{"KHTML, like Gecko"}},
				{Product: "Chrome", ProductVersion: "87.0.4280.88"},
				{Product: "Safari", ProductVersion: "537.36"},
				{Product: "surveyon", ProductVersion: "2.9.5", Comments: []string{"iPhone", "CPU iPhone OS 12_5_7 like Mac OS X"}},
			},
		},
		{
			name:  "invalid client-agent version format (starts with 'v')",
			input: "Boundary-client-agent/v0.1.4",
		},
		{
			name:  "invalid client-agent version format (non-semver)",
			input: "Boundary-client-agent/0.1.x",
		},
		{
			name:  "empty user-agent",
			input: "",
		},
		{
			name:  "valid non client-agent user-agent",
			input: "SomeOtherApp/v1.2.3",
			want: []*event.UserAgent{
				{Product: "SomeOtherApp", ProductVersion: "v1.2.3"},
			},
		},
		{
			name:  "mixed valid and invalid agents",
			input: "Boundary-client-agent/0.1.4 NoVersionApp SomeOtherApp/",
			want: []*event.UserAgent{
				{Product: "Boundary-client-agent", ProductVersion: "0.1.4"},
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := parseUserAgents(tc.input)
			assert.ElementsMatch(t, tc.want, got)
		})
	}
}
// Test_workerRequestInfoInterceptor exercises workerRequestInfoInterceptor.
//
// The factory must reject a nil eventer with an InvalidParameter error. With a
// valid eventer, the returned interceptor must attach both request info (with
// a non-empty Id and EventId, and Method set to the gRPC FullMethod) and the
// eventer to the handler's context.
func Test_workerRequestInfoInterceptor(t *testing.T) {
	factoryCtx := context.Background()
	requestCtx := context.Background()
	// returnCtxHandler hands back the context it was invoked with so the test
	// can inspect what the interceptor attached to it.
	returnCtxHandler := func(ctx context.Context, req any) (any, error) {
		return ctx, nil
	}
	// Name the eventer after this test (previously copy-pasted from
	// Test_unaryCtxInterceptor) so its sinks and output are attributable.
	c := event.TestEventerConfig(t, "Test_workerRequestInfoInterceptor", event.TestWithAuditSink(t), event.TestWithObservationSink(t))
	testLock := &sync.Mutex{}
	testLogger := hclog.New(&hclog.LoggerOptions{
		Mutex: testLock,
		Name:  "test",
	})
	testEventer, err := event.NewEventer(testLogger, testLock, "Test_workerRequestInfoInterceptor", c.EventerConfig)
	require.NoError(t, err)
	tests := []struct {
		name                   string
		requestCtx             context.Context
		eventer                *event.Eventer
		wantFactoryErr         bool
		wantFactoryErrMatch    *errors.Template
		wantFactoryErrContains string
		wantRequestErr         bool
		wantRequestErrMatch    *errors.Template
		wantRequestErrContains string
	}{
		{
			name:                   "missing-eventer",
			requestCtx:             requestCtx,
			wantFactoryErr:         true,
			wantFactoryErrMatch:    errors.T(errors.InvalidParameter),
			wantFactoryErrContains: "missing eventer",
		},
		{
			name:       "valid",
			requestCtx: requestCtx,
			eventer:    testEventer,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			interceptor, err := workerRequestInfoInterceptor(factoryCtx, tt.eventer)
			if tt.wantFactoryErr {
				require.Error(err)
				assert.Nil(interceptor)
				if tt.wantFactoryErrMatch != nil {
					assert.Truef(errors.Match(tt.wantFactoryErrMatch, err), "want err code: %q got: %q", tt.wantFactoryErrMatch.Code, err)
				}
				if tt.wantFactoryErrContains != "" {
					assert.Contains(err.Error(), tt.wantFactoryErrContains)
				}
				return
			}
			require.NoError(err)
			assert.NotNil(interceptor)
			info := &grpc.UnaryServerInfo{
				FullMethod: "FakeMethod",
			}
			retCtx, err := interceptor(tt.requestCtx, nil, info, returnCtxHandler)
			if tt.wantRequestErr {
				require.Error(err)
				assert.Nil(retCtx)
				if tt.wantRequestErrMatch != nil {
					assert.Truef(errors.Match(tt.wantRequestErrMatch, err), "want err code: %q got: %q", tt.wantRequestErrMatch.Code, err)
				}
				if tt.wantRequestErrContains != "" {
					assert.Contains(err.Error(), tt.wantRequestErrContains)
				}
				return
			}
			require.NoError(err)
			requestInfo, found := event.RequestInfoFromContext(retCtx.(context.Context))
			require.True(found)
			assert.NotNil(requestInfo)
			assert.NotEmpty(requestInfo.Id)
			assert.NotEmpty(requestInfo.EventId)
			assert.Equal("FakeMethod", requestInfo.Method)
			eventer, found := event.EventerFromContext(retCtx.(context.Context))
			require.True(found)
			assert.NotNil(eventer)
		})
	}
}
// testGreeter is a GreeterService implementation whose SayHello result is
// chosen by the request's Name. This lets interceptor tests trigger API
// errors, domain errors, (nil, nil) results and ordinary responses on demand.
type testGreeter struct {
	interceptor.UnimplementedGreeterServiceServer
}

// SayHello returns a canned response/error pair selected by req.GetName():
//   - "api-error":    an empty response and a *handlers.ApiError with HTTP 500
//   - "domain-error": a "hello" response and an errors.Internal domain error
//   - "nil-nil":      a nil response and a nil error
//   - "nil-err":      a nil response and an errors.Internal domain error
//   - anything else:  a "hello" response and a nil error
func (*testGreeter) SayHello(ctx context.Context, req *interceptor.SayHelloRequest) (*interceptor.SayHelloResponse, error) {
	const op = "SayHello"
	switch name := req.GetName(); name {
	case "nil-nil":
		return nil, nil
	case "nil-err":
		return nil, errors.New(ctx, errors.Internal, op, "nil response error msg")
	case "domain-error":
		return &interceptor.SayHelloResponse{Message: "hello"}, errors.New(ctx, errors.Internal, op, "domain error msg")
	case "api-error":
		apiErr := &handlers.ApiError{
			Status: http.StatusInternalServerError,
			Inner: &pb_api.Error{
				Kind:    "internal",
				Op:      op,
				Message: "api error msg",
			},
		}
		return &interceptor.SayHelloResponse{}, apiErr
	}
	return &interceptor.SayHelloResponse{Message: "hello"}, nil
}
// Test_correlationIdInterceptor verifies correlationIdInterceptor.
//
// The interceptor must reject requests with no metadata, no correlation id, more
// than one correlation id, or an invalid correlation id. For a single valid id
// it must expose that id via event.CorrelationIdFromContext on the handler's
// context.
func Test_correlationIdInterceptor(t *testing.T) {
	interceptor := correlationIdInterceptor(context.Background())
	require.NotNil(t, interceptor)
	corId, err := uuid.GenerateUUID()
	require.NoError(t, err)
	// returnCtxHandler hands back the context it was invoked with so the test
	// can inspect what the interceptor attached to it.
	returnCtxHandler := func(ctx context.Context, req any) (any, error) {
		return ctx, nil
	}
	cases := []struct {
		name       string
		ctx        context.Context
		wantCorId  string
		wantErr    bool
		wantErrStr string
	}{
		{
			name:       "no metadata",
			ctx:        context.Background(),
			wantErr:    true,
			wantErrStr: "controller.correlationIdInterceptor: no metadata",
		},
		{
			name: "no correlation id",
			ctx: metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{
				"not-correlation-id": "this is not a correlation id",
			})),
			wantErr:    true,
			wantErrStr: "controller.correlationIdInterceptor: missing correlation id metadata",
		},
		{
			name: "too many correlation ids",
			ctx: metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{
				globals.CorrelationIdKey: corId,
				"x-Correlation-id":       corId, // metadata.New does a toLower so this is an easy way to add multiple of same key
			})),
			wantErr:    true,
			wantErrStr: "controller.correlationIdInterceptor: expected 1 value for x-correlation-id metadata and got 2",
		},
		{
			name: "invalid correlation id",
			ctx: metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{
				globals.CorrelationIdKey: "this is wrong",
			})),
			wantErr:    true,
			wantErrStr: "controller.correlationIdInterceptor: failed to validated correlation id",
		},
		{
			name: "valid correlation id",
			ctx: metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{
				globals.CorrelationIdKey: corId,
			})),
			wantCorId: corId,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := &pbs.GetAccountRequest{Id: "test"}
			retCtx, err := interceptor(tc.ctx, req, nil, returnCtxHandler)
			if tc.wantErr {
				require.Error(t, err)
				require.Contains(t, err.Error(), tc.wantErrStr)
				return
			}
			// require (not assert) so an unexpected error stops the subtest
			// before the type assertion on a nil retCtx panics.
			require.NoError(t, err)
			gotCorId, ok := event.CorrelationIdFromContext(retCtx.(context.Context))
			require.True(t, ok)
			assert.Equal(t, tc.wantCorId, gotCorId)
		})
	}
}