|
|
|
|
@ -30,6 +30,7 @@ import (
|
|
|
|
|
"github.com/mr-tron/base58"
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
"google.golang.org/genproto/googleapis/api/httpbody"
|
|
|
|
|
"google.golang.org/grpc"
|
|
|
|
|
"google.golang.org/grpc/metadata"
|
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
|
@ -314,7 +315,7 @@ func Test_requestCtxInterceptor(t *testing.T) {
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
|
interceptor, err := requestCtxInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer)
|
|
|
|
|
interceptor, err := requestCtxUnaryInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer)
|
|
|
|
|
if tt.wantFactoryErr {
|
|
|
|
|
require.Error(err)
|
|
|
|
|
assert.Nil(interceptor)
|
|
|
|
|
@ -351,6 +352,310 @@ func Test_requestCtxInterceptor(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_streamCtxInterceptor(t *testing.T) {
|
|
|
|
|
t.Parallel()
|
|
|
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
|
|
|
rw := db.New(conn)
|
|
|
|
|
wrapper := db.TestWrapper(t)
|
|
|
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
|
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
|
|
|
|
|
|
|
|
iamRepoFn := func() (*iam.Repository, error) {
|
|
|
|
|
return iamRepo, nil
|
|
|
|
|
}
|
|
|
|
|
atRepoFn := func() (*authtoken.Repository, error) {
|
|
|
|
|
return authtoken.NewRepository(rw, rw, kmsCache)
|
|
|
|
|
}
|
|
|
|
|
serversRepoFn := func() (*server.Repository, error) {
|
|
|
|
|
return server.NewRepository(rw, rw, kmsCache)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
validGatewayTicket := "valid-ticket"
|
|
|
|
|
|
|
|
|
|
o, _ := iam.TestScopes(t, iamRepo)
|
|
|
|
|
at := authtoken.TestAuthToken(t, conn, kmsCache, o.GetPublicId())
|
|
|
|
|
encToken, err := authtoken.EncryptToken(context.Background(), kmsCache, o.GetPublicId(), at.GetPublicId(), at.GetToken())
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
tokValue := at.GetPublicId() + "_" + encToken
|
|
|
|
|
|
|
|
|
|
newReqCtx := func(gwTicket string) context.Context {
|
|
|
|
|
req := httptest.NewRequest("GET", "http://127.0.0.1/v1/scopes/o_1", nil)
|
|
|
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokValue))
|
|
|
|
|
// Add values for authn/authz checking
|
|
|
|
|
requestInfo := authpb.RequestInfo{
|
|
|
|
|
Path: req.URL.Path,
|
|
|
|
|
Method: req.Method,
|
|
|
|
|
EventId: "test-event-id",
|
|
|
|
|
TraceId: "test-trace-id",
|
|
|
|
|
}
|
|
|
|
|
requestInfo.PublicId, requestInfo.EncryptedToken, requestInfo.TokenFormat = auth.GetTokenFromRequest(context.TODO(), kmsCache, req)
|
|
|
|
|
requestInfo.Ticket = gwTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy
|
|
|
|
|
marshalledRequestInfo, err := proto.Marshal(&requestInfo)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding(marshalledRequestInfo))
|
|
|
|
|
mdCtx := metadata.NewIncomingContext(context.Background(), md)
|
|
|
|
|
|
|
|
|
|
md, ok := metadata.FromIncomingContext(mdCtx)
|
|
|
|
|
require.True(t, ok)
|
|
|
|
|
require.NotNil(t, md)
|
|
|
|
|
|
|
|
|
|
return mdCtx
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
factoryCtx := context.Background()
|
|
|
|
|
|
|
|
|
|
c := event.TestEventerConfig(t, "Test_requestCtxInterceptor", event.TestWithAuditSink(t), event.TestWithObservationSink(t))
|
|
|
|
|
testLock := &sync.Mutex{}
|
|
|
|
|
testLogger := hclog.New(&hclog.LoggerOptions{
|
|
|
|
|
Mutex: testLock,
|
|
|
|
|
Name: "test",
|
|
|
|
|
})
|
|
|
|
|
testEventer, err := event.NewEventer(testLogger, testLock, "Test_requestCtxInterceptor", c.EventerConfig)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
tests := []struct {
|
|
|
|
|
name string
|
|
|
|
|
requestCtx context.Context
|
|
|
|
|
iamRepoFn common.IamRepoFactory
|
|
|
|
|
authTokenRepoFn common.AuthTokenRepoFactory
|
|
|
|
|
serversRepoFn common.ServersRepoFactory
|
|
|
|
|
kms *kms.Kms
|
|
|
|
|
eventer *event.Eventer
|
|
|
|
|
ticket string
|
|
|
|
|
wantFactoryErr bool
|
|
|
|
|
wantFactoryErrMatch *errors.Template
|
|
|
|
|
wantFactoryErrContains string
|
|
|
|
|
wantRequestErr bool
|
|
|
|
|
wantRequestErrMatch *errors.Template
|
|
|
|
|
wantRequestErrContains string
|
|
|
|
|
}{
|
|
|
|
|
{
|
|
|
|
|
name: "missing-iam-repo",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing iam repo",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-at-repo",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing auth token repo",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-servers-repo",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing server repo function",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-kms",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing kms",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-eventer",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing kms",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-factory-ticket",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
wantFactoryErr: true,
|
|
|
|
|
wantFactoryErrMatch: errors.T(errors.InvalidParameter),
|
|
|
|
|
wantFactoryErrContains: "missing ticket",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-metadata",
|
|
|
|
|
requestCtx: context.Background(),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: validGatewayTicket,
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "No metadata",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "too-many-request-info-metadata",
|
|
|
|
|
requestCtx: func() context.Context {
|
|
|
|
|
md := metadata.Pairs(requestInfoMdKey, "first", requestInfoMdKey, "second")
|
|
|
|
|
return metadata.NewIncomingContext(context.Background(), md)
|
|
|
|
|
}(),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: validGatewayTicket,
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "expected 1 value",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "request-info-metadata-not-encoded",
|
|
|
|
|
requestCtx: func() context.Context {
|
|
|
|
|
md := metadata.Pairs(requestInfoMdKey, "hello")
|
|
|
|
|
return metadata.NewIncomingContext(context.Background(), md)
|
|
|
|
|
}(),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: validGatewayTicket,
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "unable to decode request info",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "request-info-metadata-not-proto",
|
|
|
|
|
requestCtx: func() context.Context {
|
|
|
|
|
md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding([]byte("hello")))
|
|
|
|
|
return metadata.NewIncomingContext(context.Background(), md)
|
|
|
|
|
}(),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: validGatewayTicket,
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "unable to unmarshal request info",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "missing-request-ticket",
|
|
|
|
|
requestCtx: newReqCtx(""),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: "validGatewayTicket",
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "Invalid context (missing ticket)",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "bad-ticket",
|
|
|
|
|
requestCtx: newReqCtx("bad-ticket"),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: "validGatewayTicket",
|
|
|
|
|
wantRequestErr: true,
|
|
|
|
|
wantRequestErrMatch: errors.T(errors.Internal),
|
|
|
|
|
wantRequestErrContains: "Invalid context (bad ticket)",
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
name: "valid",
|
|
|
|
|
requestCtx: newReqCtx(validGatewayTicket),
|
|
|
|
|
iamRepoFn: iamRepoFn,
|
|
|
|
|
authTokenRepoFn: atRepoFn,
|
|
|
|
|
serversRepoFn: serversRepoFn,
|
|
|
|
|
kms: kmsCache,
|
|
|
|
|
eventer: testEventer,
|
|
|
|
|
ticket: validGatewayTicket,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
|
interceptor, err := requestCtxStreamInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer)
|
|
|
|
|
if tt.wantFactoryErr {
|
|
|
|
|
require.Error(err)
|
|
|
|
|
assert.Nil(interceptor)
|
|
|
|
|
if tt.wantFactoryErrMatch != nil {
|
|
|
|
|
assert.Truef(errors.Match(tt.wantFactoryErrMatch, err), "want err code: %q got: %q", tt.wantFactoryErrMatch.Code, err)
|
|
|
|
|
}
|
|
|
|
|
if tt.wantFactoryErrContains != "" {
|
|
|
|
|
assert.Contains(err.Error(), tt.wantFactoryErrContains)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(err)
|
|
|
|
|
assert.NotNil(interceptor)
|
|
|
|
|
|
|
|
|
|
info := &grpc.StreamServerInfo{
|
|
|
|
|
FullMethod: "FakeMethod",
|
|
|
|
|
IsClientStream: true,
|
|
|
|
|
}
|
|
|
|
|
var hdCtx context.Context
|
|
|
|
|
|
|
|
|
|
hd := func(srv interface{}, stream grpc.ServerStream) error {
|
|
|
|
|
hdCtx = stream.Context()
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
m := &streamMock{ctx: tt.requestCtx}
|
|
|
|
|
err = interceptor(nil, m, info, hd)
|
|
|
|
|
if tt.wantRequestErr {
|
|
|
|
|
require.Error(err)
|
|
|
|
|
if tt.wantRequestErrMatch != nil {
|
|
|
|
|
assert.Truef(errors.Match(tt.wantRequestErrMatch, err), "want err code: %q got: %q", tt.wantRequestErrMatch.Code, err)
|
|
|
|
|
}
|
|
|
|
|
if tt.wantRequestErrContains != "" {
|
|
|
|
|
assert.Contains(err.Error(), tt.wantRequestErrContains)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
require.NoError(err)
|
|
|
|
|
verifyResults := auth.Verify(hdCtx.(context.Context))
|
|
|
|
|
assert.NotEmpty(verifyResults)
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type streamMock struct {
|
|
|
|
|
grpc.ServerStream
|
|
|
|
|
ctx context.Context
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *streamMock) Context() context.Context {
|
|
|
|
|
return m.ctx
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *streamMock) Send(req *httpbody.HttpBody) error {
|
|
|
|
|
panic("send not implemented")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *streamMock) RecvToClient() (*httpbody.HttpBody, error) {
|
|
|
|
|
panic("recv not implemented")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func Test_errorInterceptor(t *testing.T) {
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
tests := []struct {
|
|
|
|
|
|