refact(subtypes): Remove return from transformations

This return value is just a reference to the input, which
is being mutated. We may as well make that clear via
the signature, which will also simplify the new custom
transformation function registration logic.
pull/2031/head
Johan Brandhorst-Satzkorn 4 years ago
parent 53481146fd
commit 1fd44ee0f1

@ -93,7 +93,7 @@ func messageDomain(m proto.Message) string {
//
// Also note that for any of the id based lookups to function, the file that contains
// the proto.Message definition must set the "domain" custom option.
func transformRequestAttributes(req proto.Message) (proto.Message, error) {
func transformRequestAttributes(req proto.Message) error {
domain := messageDomain(req)
r := req.ProtoReflect()
@ -115,7 +115,7 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) {
case itemField != nil:
itemR := itemField.Message()
if itemR == nil {
return req, nil
return nil
}
id := fieldValue(r, idField)
@ -136,19 +136,15 @@ func transformRequestAttributes(req proto.Message) (proto.Message, error) {
case typeField != nil && t != "":
st = Subtype(t)
default: // need either type or id
return req, nil
}
if err := convertAttributesToSubtype(item, st); err != nil {
return req, err
return nil
}
return convertAttributesToSubtype(item, st)
case idField != nil && attributesField != nil:
id := r.Get(idField).String()
st = SubtypeFromId(domain, id)
if err := convertAttributesToSubtype(req, st); err != nil {
return req, err
}
return convertAttributesToSubtype(req, st)
}
return req, nil
return nil
}
func transformResponseItemAttributes(item proto.Message) error {
@ -206,7 +202,7 @@ func transformResponseItemAttributes(item proto.Message) error {
// // other subtype attributes types
// }
// }
func transformResponseAttributes(res proto.Message) (proto.Message, error) {
func transformResponseAttributes(res proto.Message) error {
r := res.ProtoReflect()
fields := r.Descriptor().Fields()
@ -215,27 +211,25 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) {
switch {
case itemField != nil:
if itemR := itemField.Message(); itemR == nil {
return res, nil
return nil
}
item := r.Get(itemField).Message().Interface()
if err := transformResponseItemAttributes(item); err != nil {
return res, err
}
return transformResponseItemAttributes(item)
case itemsField != nil:
if !itemsField.IsList() {
return res, nil
return nil
}
items := r.Get(itemsField).List()
for i := 0; i < items.Len(); i++ {
item := items.Get(i).Message().Interface()
if err := transformResponseItemAttributes(item); err != nil {
return res, err
return err
}
}
}
return res, nil
return nil
}
// AttributeTransformerInterceptor is a grpc server interceptor that will
@ -285,10 +279,8 @@ func transformResponseAttributes(res proto.Message) (proto.Message, error) {
func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerInterceptor {
const op = "subtypes.AttributeTransformInterceptor"
return func(interceptorCtx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var err error
if reqMsg, ok := req.(proto.Message); ok {
req, err = transformRequestAttributes(reqMsg)
if err != nil {
if err := transformRequestAttributes(reqMsg); err != nil {
return nil, handlers.InvalidArgumentErrorf("Error in provided request.",
map[string]string{"attributes": "Attribute fields do not match the expected format."})
}
@ -297,8 +289,7 @@ func AttributeTransformerInterceptor(_ context.Context) grpc.UnaryServerIntercep
res, handlerErr := handler(interceptorCtx, req)
if res, ok := res.(proto.Message); ok {
res, err = transformResponseAttributes(res)
if err != nil {
if err := transformResponseAttributes(res); err != nil {
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "failed building attribute struct: %v", err)
}
}

@ -298,9 +298,9 @@ func TestTransformRequestAttributes(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := transformRequestAttributes(tc.req)
err := transformRequest(tc.req)
require.NoError(t, err)
assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform()))
assert.Empty(t, cmp.Diff(tc.req, tc.expected, protocmp.Transform()))
})
}
}
@ -582,9 +582,9 @@ func TestTransformResponseAttributes(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := transformResponseAttributes(tc.resp)
err := transformResponse(tc.resp)
require.NoError(t, err)
assert.Empty(t, cmp.Diff(got, tc.expected, protocmp.Transform()))
assert.Empty(t, cmp.Diff(tc.resp, tc.expected, protocmp.Transform()))
})
}
}

Loading…
Cancel
Save