From 86f2535aaf37e7d0338acd5a356aa0ca10461b64 Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Mon, 25 Oct 2021 10:36:58 -0700 Subject: [PATCH] internal/libs/patchstruct: add package for patching structpb (#1592) * internal/libs/patchstruct: add package for patching structpb This adds a small helper package for patching *structpb.Struct values using the same patterns that our APIs use for updating records. The intent of this package is to provide a fallback mechanism for working with these values when said record update behavior is unavailable, such as when working with plugin-based host subtypes where the exact structure of the subtype's attributes is unavailable. patchstruct.Patch follows the following logic: * The source (src) map is merged into the destination map (dst). * Values are overwritten by the source map if they exist in both. * Values are deleted from the destination if they are set to null in the source. * Maps are recursively applied, meaning that a nested map at key "foo" in the destination would be patched with a map at key "foo" in the source. * A map in the destination is overwritten by a non-map in the source, and a non-map in the destination is overwritten by a map in the source. Patch returns the updated map as a copy, dst and src are not altered. * Add doc comment for patchstruct.Patch * vim snafu * Premature commit of repository_host_catalog.go * add nil cases * internal/libs/patchstruct: add PatchJSON * small comment formatting fix * Change PatchJSON to PatchBytes PatchJSON wasn't what was needed here - we need a helper to patch the protobuf encoding of structpb.Struct. * remove equality for original src/dst for patchbytes --- internal/libs/patchstruct/patchstruct.go | 90 ++++++++ internal/libs/patchstruct/patchstruct_test.go | 198 ++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 internal/libs/patchstruct/patchstruct.go create mode 100644 internal/libs/patchstruct/patchstruct_test.go diff --git a/internal/libs/patchstruct/patchstruct.go b/internal/libs/patchstruct/patchstruct.go new file mode 100644 index 0000000000..74798006df --- /dev/null +++ b/internal/libs/patchstruct/patchstruct.go @@ -0,0 +1,90 @@ +package patchstruct + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" +) + +// PatchStruct updates the struct found in dst with the values found in src. +// The intent of this helper is to provide a fallback mechanism for subtype +// attributes when the actual schema of the subtype attributes are unknown. As +// such, it's preferred to use other methods (such as mask mapping) when an +// actual message for the subtype is known. +// +// The following rules apply: +// +// * The source (src) map is merged into the destination map (dst). +// +// * Values are overwritten by the source map if they exist in both. +// +// * Values are deleted from the destination if they are set to null in the +// source. +// +// * Maps are recursively applied, meaning that a nested map at key "foo" in +// the destination would be patched with a map at key "foo" in the +// source. +// +// * A map in the destination is overwritten by a non-map in the source, +// and a non-map in the destination is overwritten by a map in the +// source. +// +// PatchStruct returns the updated map as a copy, dst and src are not altered. +func PatchStruct(dst, src *structpb.Struct) *structpb.Struct { + result, err := structpb.NewStruct(patchM(dst.AsMap(), src.AsMap())) + if err != nil { + // Should never error as values are source from structpb values + panic(err) + } + + return result +} + +// PatchBytes follows the same rules as above with PatchStruct, but +// instead of patching structpb.Structs, it patches the protobuf +// encoding. An error is returned if there are issues working with +// the data. +func PatchBytes(dst, src []byte) ([]byte, error) { + srcpb, dstpb := new(structpb.Struct), new(structpb.Struct) + if err := proto.Unmarshal(dst, dstpb); err != nil { + return nil, fmt.Errorf("error reading destination data: %w", err) + } + + if err := proto.Unmarshal(src, srcpb); err != nil { + return nil, fmt.Errorf("error reading source data: %w", err) + } + + result, err := proto.Marshal(PatchStruct(dstpb, srcpb)) + if err != nil { + return nil, fmt.Errorf("error writing result data: %w", err) + } + + return result, nil +} + +func patchM(dst, src map[string]interface{}) map[string]interface{} { + for k, v := range src { + switch x := v.(type) { + case map[string]interface{}: + if y, ok := dst[k].(map[string]interface{}); ok { + // If the value in dst a map, continue to patch + dst[k] = patchM(y, x) + } else { + // Overwrite + dst[k] = x + } + + default: + if v == nil { + // explicit null values delete values at that key in dst + delete(dst, k) + } else { + // Anything else gets overwritten + dst[k] = v + } + } + } + + return dst +} diff --git a/internal/libs/patchstruct/patchstruct_test.go b/internal/libs/patchstruct/patchstruct_test.go new file mode 100644 index 0000000000..c043be4135 --- /dev/null +++ b/internal/libs/patchstruct/patchstruct_test.go @@ -0,0 +1,198 @@ +package patchstruct + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" +) + +type testCase struct { + name string + dst map[string]interface{} + src map[string]interface{} + expected map[string]interface{} +} + +var testCases = []testCase{ + { + name: "merge", + dst: map[string]interface{}{ + "foo": "bar", + }, + src: map[string]interface{}{ + "baz": "qux", + }, + expected: map[string]interface{}{ + "foo": "bar", + "baz": "qux", + }, + }, + { + name: "overwrite", + dst: map[string]interface{}{ + "foo": "bar", + }, + src: map[string]interface{}{ + "foo": "baz", + }, + expected: map[string]interface{}{ + "foo": "baz", + }, + }, + { + name: "delete", + dst: map[string]interface{}{ + "foo": "bar", + "baz": "qux", + }, + src: map[string]interface{}{ + "baz": nil, + }, + expected: map[string]interface{}{ + "foo": "bar", + }, + }, + { + name: "recursive", + dst: map[string]interface{}{ + "nested": map[string]interface{}{ + "a": "b", + }, + "foo": "bar", + }, + src: map[string]interface{}{ + "nested": map[string]interface{}{ + "c": "d", + }, + }, + expected: map[string]interface{}{ + "nested": map[string]interface{}{ + "a": "b", + "c": "d", + }, + "foo": "bar", + }, + }, + { + name: "overwrite with map in src", + dst: map[string]interface{}{ + "foo": "bar", + }, + src: map[string]interface{}{ + "foo": map[string]interface{}{ + "a": "b", + }, + }, + expected: map[string]interface{}{ + "foo": map[string]interface{}{ + "a": "b", + }, + }, + }, + { + name: "nil src", + dst: map[string]interface{}{ + "foo": "bar", + }, + src: nil, + expected: map[string]interface{}{ + "foo": "bar", + }, + }, + { + name: "nil dst", + dst: nil, + src: map[string]interface{}{ + "foo": "bar", + }, + expected: map[string]interface{}{ + "foo": "bar", + }, + }, +} + +func TestPatchStruct(t *testing.T) { + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + dst, src := mustStruct(tc.dst), mustStruct(tc.src) + dstOrig, srcOrig := mustStruct(tc.dst), mustStruct(tc.src) + if tc.dst == nil { + dst = nil + dstOrig = nil + } + if tc.src == nil { + src = nil + srcOrig = nil + } + + actual := PatchStruct(dst, src) + require.Equal(mustStruct(tc.expected), actual) + require.Equal(dstOrig, dst) + require.Equal(srcOrig, src) + }) + } +} + +func TestPatchBytes(t *testing.T) { + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + dst, src := mustMarshal(tc.dst), mustMarshal(tc.src) + + actual, err := PatchBytes(dst, src) + require.NoError(err) + requireEqualEncoded(t, mustMarshal(tc.expected), actual) + }) + } +} + +func TestPatchBytesErr(t *testing.T) { + t.Run("dst", func(t *testing.T) { + require := require.New(t) + _, err := PatchBytes([]byte("foo"), nil) + require.EqualError(err, "error reading destination data: proto: cannot parse invalid wire-format data") + }) + t.Run("src", func(t *testing.T) { + require := require.New(t) + _, err := PatchBytes(nil, []byte("foo")) + require.EqualError(err, "error reading source data: proto: cannot parse invalid wire-format data") + }) +} + +func mustStruct(in map[string]interface{}) *structpb.Struct { + out, err := structpb.NewStruct(in) + if err != nil { + panic(err) + } + + return out +} + +func mustMarshal(in map[string]interface{}) []byte { + b, err := proto.Marshal(mustStruct(in)) + if err != nil { + panic(err) + } + + return b +} + +func requireEqualEncoded(t *testing.T, expected, actual []byte) { + t.Helper() + require := require.New(t) + + expectedpb, actualpb := new(structpb.Struct), new(structpb.Struct) + + err := proto.Unmarshal(expected, expectedpb) + require.NoError(err) + + err = proto.Unmarshal(actual, actualpb) + require.NoError(err) + + require.Equal(expectedpb, actualpb) +}