From 1fd44ee0f1aaf35d428c746546e6f97f66cf3af2 Mon Sep 17 00:00:00 2001 From: Johan Brandhorst-Satzkorn Date: Fri, 8 Apr 2022 15:59:18 -0400 Subject: [PATCH] refact(subtypes): Remove return from transformations This return value is just a reference to the input, which is being mutated. We may as well make that clear via the signature, which will also simplify the new custom transformation function registration logic. --- internal/types/subtypes/interceptor.go | 37 ++++++++------------- internal/types/subtypes/interceptor_test.go | 8 ++--- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/internal/types/subtypes/interceptor.go b/internal/types/subtypes/interceptor.go index 9fe7e8baef..77656f2032 100644 --- a/internal/types/subtypes/interceptor.go +++ b/internal/types/subtypes/interceptor.go @@ -93,7 +93,7 @@ func messageDomain(m proto.Message) string { // // Also note that for any of the id based lookups to function, the file that contains // the proto.Message definition must set the "domain" custom option. -func transformRequestAttributes(req proto.Message) (proto.Message, error) { +func transformRequestAttributes(req proto.Message) error { domain := messageDomain(req) r := req.ProtoReflect() @@ -115,7 +115,7 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) { case itemField != nil: itemR := itemField.Message() if itemR == nil { - return req, nil + return nil } id := fieldValue(r, idField) @@ -136,19 +136,15 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) { case typeField != nil && t != "": st = Subtype(t) default: // need either type or id - return req, nil - } - if err := convertAttributesToSubtype(item, st); err != nil { - return req, err + return nil } + return convertAttributesToSubtype(item, st) case idField != nil && attributesField != nil: id := r.Get(idField).String() st = SubtypeFromId(domain, id) - if err := convertAttributesToSubtype(req, st); err != nil { - return req, err - } + return convertAttributesToSubtype(req, st) } - return req, nil + return nil } func transformResponseItemAttributes(item proto.Message) error { @@ -206,7 +202,7 @@ func transformResponseItemAttributes(item proto.Message) error { // // other subtype attributes types // } // } -func transformResponseAttributes(res proto.Message) (proto.Message, error) { +func transformResponseAttributes(res proto.Message) error { r := res.ProtoReflect() fields := r.Descriptor().Fields() @@ -215,27 +211,25 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) { switch { case itemField != nil: if itemR := itemField.Message(); itemR == nil { - return res, nil + return nil } item := r.Get(itemField).Message().Interface() - if err := transformResponseItemAttributes(item); err != nil { - return res, err - } + return transformResponseItemAttributes(item) case itemsField != nil: if !itemsField.IsList() { - return res, nil + return nil } items := r.Get(itemsField).List() for i := 0; i < items.Len(); i++ { item := items.Get(i).Message().Interface() if err := transformResponseItemAttributes(item); err != nil { - return res, err + return err } } } - return res, nil + return nil } // AttributeTransformerInterceptor is a grpc server interceptor that will @@ -285,10 +279,8 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) { func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerInterceptor { const op = "subtypes.AttributeTransformInterceptor" return func(interceptorCtx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - var err error if reqMsg, ok := req.(proto.Message); ok { - req, err = transformRequestAttributes(reqMsg) - if err != nil { + if err := transformRequestAttributes(reqMsg); err != nil { return nil, handlers.InvalidArgumentErrorf("Error in provided request.", map[string]string{"attributes": "Attribute fields do not match the expected format."}) } @@ -297,8 +289,7 @@ func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerIntercep res, handlerErr := handler(interceptorCtx, req) if res, ok := res.(proto.Message); ok { - res, err = transformResponseAttributes(res) - if err != nil { + if err := transformResponseAttributes(res); err != nil { return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "failed building attribute struct: %v", err) } } diff --git a/internal/types/subtypes/interceptor_test.go b/internal/types/subtypes/interceptor_test.go index 3a15853217..1e2c828f88 100644 --- a/internal/types/subtypes/interceptor_test.go +++ b/internal/types/subtypes/interceptor_test.go @@ -298,9 +298,9 @@ func TestTransformRequestAttributes(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - got, err := transformRequestAttributes(tc.req) + err := transformRequest(tc.req) require.NoError(t, err) - assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform())) + assert.Empty(t, cmp.Diff(tc.req, tc.expected, protocmp.Transform())) }) } } @@ -582,9 +582,9 @@ func TestTransformResponseAttributes(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - got, err := transformResponseAttributes(tc.resp) + err := transformResponse(tc.resp) require.NoError(t, err) - assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform())) + assert.Empty(t, cmp.Diff(tc.resp, tc.expected, protocmp.Transform())) }) } }