diff --git a/internal/types/subtypes/interceptor.go b/internal/types/subtypes/interceptor.go index 9fe7e8baef..77656f2032 100644 --- a/internal/types/subtypes/interceptor.go +++ b/internal/types/subtypes/interceptor.go @@ -93,7 +93,7 @@ func messageDomain(m proto.Message) string { // // Also note that for any of the id based lookups to function, the file that contains // the proto.Message definition must set the "domain" custom option. -func transformRequestAttributes(req proto.Message) (proto.Message, error) { +func transformRequestAttributes(req proto.Message) error { domain := messageDomain(req) r := req.ProtoReflect() @@ -115,7 +115,7 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) { case itemField != nil: itemR := itemField.Message() if itemR == nil { - return req, nil + return nil } id := fieldValue(r, idField) @@ -136,19 +136,15 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) { case typeField != nil && t != "": st = Subtype(t) default: // need either type or id - return req, nil - } - if err := convertAttributesToSubtype(item, st); err != nil { - return req, err + return nil } + return convertAttributesToSubtype(item, st) case idField != nil && attributesField != nil: id := r.Get(idField).String() st = SubtypeFromId(domain, id) - if err := convertAttributesToSubtype(req, st); err != nil { - return req, err - } + return convertAttributesToSubtype(req, st) } - return req, nil + return nil } func transformResponseItemAttributes(item proto.Message) error { @@ -206,7 +202,7 @@ func transformResponseItemAttributes(item proto.Message) error { // // other subtype attributes types // } // } -func transformResponseAttributes(res proto.Message) (proto.Message, error) { +func transformResponseAttributes(res proto.Message) error { r := res.ProtoReflect() fields := r.Descriptor().Fields() @@ -215,27 +211,25 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) { switch { case itemField != nil: if itemR := itemField.Message(); itemR == nil { - return res, nil + return nil } item := r.Get(itemField).Message().Interface() - if err := transformResponseItemAttributes(item); err != nil { - return res, err - } + return transformResponseItemAttributes(item) case itemsField != nil: if !itemsField.IsList() { - return res, nil + return nil } items := r.Get(itemsField).List() for i := 0; i < items.Len(); i++ { item := items.Get(i).Message().Interface() if err := transformResponseItemAttributes(item); err != nil { - return res, err + return err } } } - return res, nil + return nil } // AttributeTransformerInterceptor is a grpc server interceptor that will @@ -285,10 +279,8 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) { func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerInterceptor { const op = "subtypes.AttributeTransformInterceptor" return func(interceptorCtx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - var err error if reqMsg, ok := req.(proto.Message); ok { - req, err = transformRequestAttributes(reqMsg) - if err != nil { + if err := transformRequestAttributes(reqMsg); err != nil { return nil, handlers.InvalidArgumentErrorf("Error in provided request.", map[string]string{"attributes": "Attribute fields do not match the expected format."}) } @@ -297,8 +289,7 @@ func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerIntercep res, handlerErr := handler(interceptorCtx, req) if res, ok := res.(proto.Message); ok { - res, err = transformResponseAttributes(res) - if err != nil { + if err := transformResponseAttributes(res); err != nil { return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "failed building attribute struct: %v", err) } } diff --git a/internal/types/subtypes/interceptor_test.go b/internal/types/subtypes/interceptor_test.go index 3a15853217..1e2c828f88 100644 --- a/internal/types/subtypes/interceptor_test.go +++ b/internal/types/subtypes/interceptor_test.go @@ -298,9 +298,9 @@ func TestTransformRequestAttributes(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - got, err := transformRequestAttributes(tc.req) + err := transformRequest(tc.req) require.NoError(t, err) - assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform())) + assert.Empty(t, cmp.Diff(tc.req, tc.expected, protocmp.Transform())) }) } } @@ -582,9 +582,9 @@ func TestTransformResponseAttributes(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - got, err := transformResponseAttributes(tc.resp) + err := transformResponse(tc.resp) require.NoError(t, err) - assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform())) + assert.Empty(t, cmp.Diff(tc.resp, tc.expected, protocmp.Transform())) }) } }