diff --git a/internal/daemon/controller/gateway.go b/internal/daemon/controller/gateway.go index 8d9d95c7b3..bbd35581d2 100644 --- a/internal/daemon/controller/gateway.go +++ b/internal/daemon/controller/gateway.go @@ -81,13 +81,34 @@ func newGrpcServer( if err != nil { return nil, "", errors.Wrap(ctx, err, op, errors.WithMsg("unable to generate gateway ticket")) } - requestCtxInterceptor, err := requestCtxInterceptor(ctx, iamRepoFn, authTokenRepoFn, serversRepoFn, passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn, kms, ticket, eventer) + requestCtxInterceptor, err := requestCtxUnaryInterceptor(ctx, iamRepoFn, authTokenRepoFn, serversRepoFn, passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn, kms, ticket, eventer) + if err != nil { + return nil, "", err + } + + streamCtxInterceptor, err := requestCtxStreamInterceptor( + ctx, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + passwordAuthRepoFn, + oidcAuthRepoFn, + ldapAuthRepoFn, + kms, + ticket, + eventer, + ) if err != nil { return nil, "", err } return grpc.NewServer( grpc.MaxRecvMsgSize(math.MaxInt32), grpc.MaxSendMsgSize(math.MaxInt32), + grpc.StreamInterceptor( + grpc_middleware.ChainStreamServer( + streamCtxInterceptor, + ), + ), grpc.UnaryInterceptor( grpc_middleware.ChainUnaryServer( requestCtxInterceptor, // populated requestInfo from headers into the request ctx diff --git a/internal/daemon/controller/interceptor.go b/internal/daemon/controller/interceptor.go index 4182dc78b4..770a257061 100644 --- a/internal/daemon/controller/interceptor.go +++ b/internal/daemon/controller/interceptor.go @@ -42,11 +42,25 @@ const ( apiErrHeader = "x-api-err" ) -// requestCtxInterceptor creates an unary server interceptor that pulls grpc -// metadata into a ctx for the request. The metadata must be set in an upstream -// http handler/middleware by marshalling a RequestInfo protobuf into the -// requestInfoMdKey header (see: controller.wrapHandlerWithCommonFuncs). -func requestCtxInterceptor( +// customContextServerStream wraps the grpc.ServerStream interface and lets us +// set a custom context +type customContextServerStream struct { + grpc.ServerStream + customContext context.Context +} + +func (c *customContextServerStream) Context() context.Context { + if c.customContext != nil { + return c.customContext + } + return c.ServerStream.Context() +} + +// requestCtxUnaryInterceptor creates an unary server interceptor that pulls +// grpc metadata into a ctx for the request. The metadata must be set in an +// upstream http handler/middleware by marshalling a RequestInfo protobuf into +// the requestInfoMdKey header (see: controller.wrapHandlerWithCommonFuncs). +func requestCtxUnaryInterceptor( ctx context.Context, iamRepoFn common.IamRepoFactory, authTokenRepoFn common.AuthTokenRepoFactory, @@ -58,96 +72,209 @@ func requestCtxInterceptor( ticket string, eventer *event.Eventer, ) (grpc.UnaryServerInterceptor, error) { - const op = "controller.requestCtxInterceptor" + const op = "controller.requestCtxUnaryInterceptor" + if err := sharedRequestInterceptorValidation( + ctx, + op, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + kms, + ticket, + eventer, + ); err != nil { + return nil, err + } + + return func(interceptorCtx context.Context, + req any, + _ *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (any, error) { + updatedCtx, err := sharedRequestInterceptorLogic( + interceptorCtx, + op, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + passwordAuthRepoFn, + oidcAuthRepoFn, + ldapAuthRepoFn, + kms, + ticket, + eventer, + ) + if err != nil { + return nil, err + } + return handler(updatedCtx, req) + }, nil +} + +// requestCtxStreamInterceptor creates a stream server interceptor that pulls +// grpc metadata into a ctx for the request. The metadata must be set in an +// upstream http handler/middleware by marshalling a RequestInfo protobuf into +// the requestInfoMdKey header (see: controller.wrapHandlerWithCommonFuncs). +func requestCtxStreamInterceptor( + ctx context.Context, + iamRepoFn common.IamRepoFactory, + authTokenRepoFn common.AuthTokenRepoFactory, + serversRepoFn common.ServersRepoFactory, + passwordAuthRepoFn common.PasswordAuthRepoFactory, + oidcAuthRepoFn common.OidcAuthRepoFactory, + ldapAuthRepoFn common.LdapAuthRepoFactory, + kms *kms.Kms, + ticket string, + eventer *event.Eventer, +) (grpc.StreamServerInterceptor, error) { + const op = "controller.requestCtxStreamInterceptor" + if err := sharedRequestInterceptorValidation( + ctx, + op, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + kms, + ticket, + eventer, + ); err != nil { + return nil, err + } + return func(srv any, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, + ) error { + updatedCtx, err := sharedRequestInterceptorLogic( + ss.Context(), + op, + iamRepoFn, + authTokenRepoFn, + serversRepoFn, + passwordAuthRepoFn, + oidcAuthRepoFn, + ldapAuthRepoFn, + kms, + ticket, + eventer, + ) + if err != nil { + return err + } + css := &customContextServerStream{ + ServerStream: ss, + customContext: updatedCtx, + } + return handler(srv, css) + }, nil +} + +func sharedRequestInterceptorValidation( + ctx context.Context, + op errors.Op, + iamRepoFn common.IamRepoFactory, + authTokenRepoFn common.AuthTokenRepoFactory, + serversRepoFn common.ServersRepoFactory, + kms *kms.Kms, + ticket string, + eventer *event.Eventer, +) error { if iamRepoFn == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing iam repo function") + return errors.New(ctx, errors.InvalidParameter, op, "missing iam repo function") } if authTokenRepoFn == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token repo function") + return errors.New(ctx, errors.InvalidParameter, op, "missing auth token repo function") } if serversRepoFn == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing server repo function") + return errors.New(ctx, errors.InvalidParameter, op, "missing server repo function") } if kms == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing kms") + return errors.New(ctx, errors.InvalidParameter, op, "missing kms") } if ticket == "" { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing ticket") + return errors.New(ctx, errors.InvalidParameter, op, "missing ticket") } if eventer == nil { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing eventer") + return errors.New(ctx, errors.InvalidParameter, op, "missing eventer") } - // Authorization unary interceptor function to handle authorize per RPC call - return func(interceptorCtx context.Context, - req any, - _ *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, - ) (any, error) { - md, ok := metadata.FromIncomingContext(interceptorCtx) - if !ok { - return nil, errors.New(interceptorCtx, errors.Internal, op, "No metadata") - } - values := md.Get(requestInfoMdKey) - if len(values) == 0 { - return nil, errors.New(interceptorCtx, errors.Internal, op, "Missing request metadata") - } - if len(values) > 1 { - return nil, errors.New(interceptorCtx, errors.Internal, op, fmt.Sprintf("expected 1 value for %s metadata and got %d", requestInfoMdKey, len(values))) - } + return nil +} - decoded, err := base58.FastBase58Decoding(values[0]) - if err != nil { - return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to decode request info")) - } - var requestInfo authpb.RequestInfo - if err := proto.Unmarshal(decoded, &requestInfo); err != nil { - return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to unmarshal request info")) - } - switch { - case requestInfo.Ticket == "": - return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (missing ticket)") - case requestInfo.Ticket != ticket: - return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (bad ticket)") - } +func sharedRequestInterceptorLogic( + interceptorCtx context.Context, + op errors.Op, + iamRepoFn common.IamRepoFactory, + authTokenRepoFn common.AuthTokenRepoFactory, + serversRepoFn common.ServersRepoFactory, + passwordAuthRepoFn common.PasswordAuthRepoFactory, + oidcAuthRepoFn common.OidcAuthRepoFactory, + ldapAuthRepoFn common.LdapAuthRepoFactory, + kms *kms.Kms, + ticket string, + eventer *event.Eventer, +) (context.Context, error) { + md, ok := metadata.FromIncomingContext(interceptorCtx) + if !ok { + return nil, errors.New(interceptorCtx, errors.Internal, op, "No metadata") + } - interceptorCtx = auth.NewVerifierContextWithAccounts(interceptorCtx, iamRepoFn, authTokenRepoFn, serversRepoFn, passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn, kms, &requestInfo) - - // Add general request information to the context. The information from - // the auth verifier context is pretty specifically curated to - // authentication/authorization verification so this is more - // general-purpose. - // - // We could use requests.NewRequestContext but this saves an immediate - // lookup. - interceptorCtx = context.WithValue(interceptorCtx, requests.ContextRequestInformationKey, &requests.RequestContext{ - Path: requestInfo.Path, - Method: requestInfo.Method, - }) - - // This event request info is required by downstream handlers - info := &event.RequestInfo{ - EventId: requestInfo.EventId, - Id: requestInfo.TraceId, - PublicId: requestInfo.PublicId, - Method: requestInfo.Method, - Path: requestInfo.Path, - ClientIp: requestInfo.ClientIp, - } - interceptorCtx, err = event.NewRequestInfoContext(interceptorCtx, info) - if err != nil { - return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with request info")) - } - interceptorCtx, err = event.NewEventerContext(interceptorCtx, eventer) - if err != nil { - return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with eventer")) - } + values := md.Get(requestInfoMdKey) + if len(values) == 0 { + return nil, errors.New(interceptorCtx, errors.Internal, op, "Missing request metadata") + } + if len(values) > 1 { + return nil, errors.New(interceptorCtx, errors.Internal, op, fmt.Sprintf("expected 1 value for %s metadata and got %d", requestInfoMdKey, len(values))) + } + + decoded, err := base58.FastBase58Decoding(values[0]) + if err != nil { + return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to decode request info")) + } + var requestInfo authpb.RequestInfo + if err := proto.Unmarshal(decoded, &requestInfo); err != nil { + return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to unmarshal request info")) + } + switch { + case requestInfo.Ticket == "": + return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (missing ticket)") + case requestInfo.Ticket != ticket: + return nil, errors.New(interceptorCtx, errors.Internal, op, "Invalid context (bad ticket)") + } - // Calls the handler - h, err := handler(interceptorCtx, req) + interceptorCtx = auth.NewVerifierContextWithAccounts(interceptorCtx, iamRepoFn, authTokenRepoFn, serversRepoFn, passwordAuthRepoFn, oidcAuthRepoFn, ldapAuthRepoFn, kms, &requestInfo) - return h, err // not convinced we want to wrap every error and turn them into domain errors... - }, nil + // Add general request information to the context. The information from + // the auth verifier context is pretty specifically curated to + // authentication/authorization verification so this is more + // general-purpose. + // + // We could use requests.NewRequestContext but this saves an immediate + // lookup. + interceptorCtx = context.WithValue(interceptorCtx, requests.ContextRequestInformationKey, &requests.RequestContext{ + Path: requestInfo.Path, + Method: requestInfo.Method, + }) + + // This event request info is required by downstream handlers + info := &event.RequestInfo{ + EventId: requestInfo.EventId, + Id: requestInfo.TraceId, + PublicId: requestInfo.PublicId, + Method: requestInfo.Method, + Path: requestInfo.Path, + ClientIp: requestInfo.ClientIp, + } + interceptorCtx, err = event.NewRequestInfoContext(interceptorCtx, info) + if err != nil { + return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with request info")) + } + interceptorCtx, err = event.NewEventerContext(interceptorCtx, eventer) + if err != nil { + return nil, errors.Wrap(interceptorCtx, err, op, errors.WithCode(errors.Internal), errors.WithMsg("unable to create context with eventer")) + } + + return interceptorCtx, err // not convinced we want to wrap every error and turn them into domain errors... } func errorInterceptor( diff --git a/internal/daemon/controller/interceptor_test.go b/internal/daemon/controller/interceptor_test.go index 250fb8170b..d743cb26cc 100644 --- a/internal/daemon/controller/interceptor_test.go +++ b/internal/daemon/controller/interceptor_test.go @@ -30,6 +30,7 @@ import ( "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" @@ -314,7 +315,7 @@ func Test_requestCtxInterceptor(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - interceptor, err := requestCtxInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer) + interceptor, err := requestCtxUnaryInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer) if tt.wantFactoryErr { require.Error(err) assert.Nil(interceptor) @@ -351,6 +352,310 @@ func Test_requestCtxInterceptor(t *testing.T) { } } +func Test_streamCtxInterceptor(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrapper) + iamRepo := iam.TestRepo(t, conn, wrapper) + + iamRepoFn := func() (*iam.Repository, error) { + return iamRepo, nil + } + atRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kmsCache) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(rw, rw, kmsCache) + } + + validGatewayTicket := "valid-ticket" + + o, _ := iam.TestScopes(t, iamRepo) + at := authtoken.TestAuthToken(t, conn, kmsCache, o.GetPublicId()) + encToken, err := authtoken.EncryptToken(context.Background(), kmsCache, o.GetPublicId(), at.GetPublicId(), at.GetToken()) + require.NoError(t, err) + tokValue := at.GetPublicId() + "_" + encToken + + newReqCtx := func(gwTicket string) context.Context { + req := httptest.NewRequest("GET", "http://127.0.0.1/v1/scopes/o_1", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tokValue)) + // Add values for authn/authz checking + requestInfo := authpb.RequestInfo{ + Path: req.URL.Path, + Method: req.Method, + EventId: "test-event-id", + TraceId: "test-trace-id", + } + requestInfo.PublicId, requestInfo.EncryptedToken, requestInfo.TokenFormat = auth.GetTokenFromRequest(context.TODO(), kmsCache, req) + requestInfo.Ticket = gwTicket // allows the grpc-gateway to verify the request info came from it's in-memory companion http proxy + marshalledRequestInfo, err := proto.Marshal(&requestInfo) + require.NoError(t, err) + md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding(marshalledRequestInfo)) + mdCtx := metadata.NewIncomingContext(context.Background(), md) + + md, ok := metadata.FromIncomingContext(mdCtx) + require.True(t, ok) + require.NotNil(t, md) + + return mdCtx + } + + factoryCtx := context.Background() + + c := event.TestEventerConfig(t, "Test_requestCtxInterceptor", event.TestWithAuditSink(t), event.TestWithObservationSink(t)) + testLock := &sync.Mutex{} + testLogger := hclog.New(&hclog.LoggerOptions{ + Mutex: testLock, + Name: "test", + }) + testEventer, err := event.NewEventer(testLogger, testLock, "Test_requestCtxInterceptor", c.EventerConfig) + require.NoError(t, err) + tests := []struct { + name string + requestCtx context.Context + iamRepoFn common.IamRepoFactory + authTokenRepoFn common.AuthTokenRepoFactory + serversRepoFn common.ServersRepoFactory + kms *kms.Kms + eventer *event.Eventer + ticket string + wantFactoryErr bool + wantFactoryErrMatch *errors.Template + wantFactoryErrContains string + wantRequestErr bool + wantRequestErrMatch *errors.Template + wantRequestErrContains string + }{ + { + name: "missing-iam-repo", + requestCtx: newReqCtx(validGatewayTicket), + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing iam repo", + }, + { + name: "missing-at-repo", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing auth token repo", + }, + { + name: "missing-servers-repo", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + kms: kmsCache, + eventer: testEventer, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing server repo function", + }, + { + name: "missing-kms", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + eventer: testEventer, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing kms", + }, + { + name: "missing-eventer", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing kms", + }, + { + name: "missing-factory-ticket", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + wantFactoryErr: true, + wantFactoryErrMatch: errors.T(errors.InvalidParameter), + wantFactoryErrContains: "missing ticket", + }, + { + name: "missing-metadata", + requestCtx: context.Background(), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: validGatewayTicket, + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "No metadata", + }, + { + name: "too-many-request-info-metadata", + requestCtx: func() context.Context { + md := metadata.Pairs(requestInfoMdKey, "first", requestInfoMdKey, "second") + return metadata.NewIncomingContext(context.Background(), md) + }(), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: validGatewayTicket, + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "expected 1 value", + }, + { + name: "request-info-metadata-not-encoded", + requestCtx: func() context.Context { + md := metadata.Pairs(requestInfoMdKey, "hello") + return metadata.NewIncomingContext(context.Background(), md) + }(), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: validGatewayTicket, + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "unable to decode request info", + }, + { + name: "request-info-metadata-not-proto", + requestCtx: func() context.Context { + md := metadata.Pairs(requestInfoMdKey, base58.FastBase58Encoding([]byte("hello"))) + return metadata.NewIncomingContext(context.Background(), md) + }(), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: validGatewayTicket, + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "unable to unmarshal request info", + }, + { + name: "missing-request-ticket", + requestCtx: newReqCtx(""), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: "validGatewayTicket", + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "Invalid context (missing ticket)", + }, + { + name: "bad-ticket", + requestCtx: newReqCtx("bad-ticket"), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: "validGatewayTicket", + wantRequestErr: true, + wantRequestErrMatch: errors.T(errors.Internal), + wantRequestErrContains: "Invalid context (bad ticket)", + }, + { + name: "valid", + requestCtx: newReqCtx(validGatewayTicket), + iamRepoFn: iamRepoFn, + authTokenRepoFn: atRepoFn, + serversRepoFn: serversRepoFn, + kms: kmsCache, + eventer: testEventer, + ticket: validGatewayTicket, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + interceptor, err := requestCtxStreamInterceptor(factoryCtx, tt.iamRepoFn, tt.authTokenRepoFn, tt.serversRepoFn, nil, nil, nil, tt.kms, tt.ticket, tt.eventer) + if tt.wantFactoryErr { + require.Error(err) + assert.Nil(interceptor) + if tt.wantFactoryErrMatch != nil { + assert.Truef(errors.Match(tt.wantFactoryErrMatch, err), "want err code: %q got: %q", tt.wantFactoryErrMatch.Code, err) + } + if tt.wantFactoryErrContains != "" { + assert.Contains(err.Error(), tt.wantFactoryErrContains) + } + return + } + require.NoError(err) + assert.NotNil(interceptor) + + info := &grpc.StreamServerInfo{ + FullMethod: "FakeMethod", + IsClientStream: true, + } + var hdCtx context.Context + + hd := func(srv interface{}, stream grpc.ServerStream) error { + hdCtx = stream.Context() + return nil + } + m := &streamMock{ctx: tt.requestCtx} + err = interceptor(nil, m, info, hd) + if tt.wantRequestErr { + require.Error(err) + if tt.wantRequestErrMatch != nil { + assert.Truef(errors.Match(tt.wantRequestErrMatch, err), "want err code: %q got: %q", tt.wantRequestErrMatch.Code, err) + } + if tt.wantRequestErrContains != "" { + assert.Contains(err.Error(), tt.wantRequestErrContains) + } + return + } + require.NoError(err) + verifyResults := auth.Verify(hdCtx.(context.Context)) + assert.NotEmpty(verifyResults) + }) + } +} + +type streamMock struct { + grpc.ServerStream + ctx context.Context +} + +func (m *streamMock) Context() context.Context { + return m.ctx +} + +func (m *streamMock) Send(req *httpbody.HttpBody) error { + panic("send not implemented") +} + +func (m *streamMock) RecvToClient() (*httpbody.HttpBody, error) { + panic("recv not implemented") +} + func Test_errorInterceptor(t *testing.T) { ctx := context.Background() tests := []struct {