From 098248be18dc3533fae14518a4f3d8326d186647 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 4 Jan 2023 09:44:11 -0800 Subject: [PATCH] Add the ability for CLI map values to contain keys only, which map to JSON nulls (#2721) --- internal/cmd/base/flags.go | 38 +++-- .../credentiallibrariescmd/vault_funcs.go | 8 +- internal/cmd/common/flags.go | 63 +++++--- internal/cmd/common/flags_test.go | 150 +++++++++++++----- 4 files changed, 180 insertions(+), 79 deletions(-) diff --git a/internal/cmd/base/flags.go b/internal/cmd/base/flags.go index b864518a67..783beb8748 100644 --- a/internal/cmd/base/flags.go +++ b/internal/cmd/base/flags.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/kr/pretty" "github.com/posener/complete" + "google.golang.org/protobuf/types/known/wrapperspb" ) // FlagExample is an interface which declares an example value. @@ -909,6 +910,9 @@ func (f *FlagSet) Var(value flag.Value, name, usage string) { // Value parts so that validation can happen at parsing time. If you don't want // this kind of behavior, simply combine them, or set KvSplit to false. // +// If KeyOnlyAllowed is true then it is valid to parse an input with only a key +// segment and no value. +// // If KeyDelimiter is non-nil (along with KvSplit being true), the string will // be used to split the key. Otherwise, the Keys will be a single-element slice // containing the full value. @@ -924,6 +928,7 @@ type CombinationSliceVar struct { Target *[]CombinedSliceFlagValue Completion complete.Predictor KvSplit bool + KeyOnlyAllowed bool KeyDelimiter *string ProtoCompatKey bool } @@ -933,7 +938,7 @@ func (f *FlagSet) CombinationSliceVar(i *CombinationSliceVar) { Name: i.Name, Aliases: i.Aliases, Usage: i.Usage, - Value: newCombinedSliceValue(i.Name, i.Target, i.Hidden, i.KvSplit, i.KeyDelimiter, i.ProtoCompatKey), + Value: newCombinedSliceValue(i.Name, i.Target, i.Hidden, i.KvSplit, i.KeyOnlyAllowed, i.KeyDelimiter, i.ProtoCompatKey), Completion: i.Completion, }) } @@ -943,24 +948,26 @@ func (f *FlagSet) CombinationSliceVar(i *CombinationSliceVar) { type CombinedSliceFlagValue struct { Name string Keys []string - Value string + Value *wrapperspb.StringValue } type combinedSliceValue struct { name string hidden bool kvSplit bool + keyOnlyAllowed bool keyDelimiter *string protoCompatKey bool target *[]CombinedSliceFlagValue } -func newCombinedSliceValue(name string, target *[]CombinedSliceFlagValue, hidden, kvSplit bool, keyDelimiter *string, protoCompatKey bool) *combinedSliceValue { +func newCombinedSliceValue(name string, target *[]CombinedSliceFlagValue, hidden, kvSplit, keyOnlyAllowed bool, keyDelimiter *string, protoCompatKey bool) *combinedSliceValue { return &combinedSliceValue{ name: name, hidden: hidden, kvSplit: kvSplit, keyDelimiter: keyDelimiter, + keyOnlyAllowed: keyOnlyAllowed, protoCompatKey: protoCompatKey, target: target, } @@ -971,21 +978,27 @@ var protoIdentifierRegex = regexp.MustCompile("^[a-zA-Z][A-Za-z0-9_]*$") func (c *combinedSliceValue) Set(val string) error { ret := CombinedSliceFlagValue{ Name: c.name, - Value: strings.TrimSpace(val), + Value: wrapperspb.String(strings.TrimSpace(val)), } if c.kvSplit { - kv := strings.SplitN(ret.Value, "=", 2) + kv := strings.SplitN(ret.Value.GetValue(), "=", 2) switch len(kv) { case 0: + // This shouldn't happen + return fmt.Errorf("unexpected length of string after splitting") case 1: - ret.Value = strings.TrimSpace(kv[0]) + if !c.keyOnlyAllowed { + return fmt.Errorf("key-only value provided but not supported for this flag") + } + ret.Keys = []string{kv[0]} + ret.Value = nil default: ret.Keys = []string{kv[0]} if c.keyDelimiter != nil { ret.Keys = strings.Split(kv[0], *c.keyDelimiter) } - ret.Value = strings.TrimSpace(kv[1]) + ret.Value = wrapperspb.String(strings.TrimSpace(kv[1])) } } @@ -1002,9 +1015,14 @@ func (c *combinedSliceValue) Set(val string) error { } } - var err error - if ret.Value, err = parseutil.ParsePath(ret.Value); err != nil && !errors.Is(err, parseutil.ErrNotAUrl) { - return fmt.Errorf("error checking if value is a path: %w", err) + if ret.Value != nil { + pathParsedValue, err := parseutil.ParsePath(ret.Value.GetValue()) + if err != nil && !errors.Is(err, parseutil.ErrNotAUrl) { + return fmt.Errorf("error checking if value is a path: %w", err) + } + if pathParsedValue != "" { + ret.Value = wrapperspb.String(pathParsedValue) + } } *c.target = append(*c.target, ret) diff --git a/internal/cmd/commands/credentiallibrariescmd/vault_funcs.go b/internal/cmd/commands/credentiallibrariescmd/vault_funcs.go index 9831fa7f9c..d066a2cde5 100644 --- a/internal/cmd/commands/credentiallibrariescmd/vault_funcs.go +++ b/internal/cmd/commands/credentiallibrariescmd/vault_funcs.go @@ -118,7 +118,7 @@ func extraVaultFlagHandlingFuncImpl(c *VaultCommand, _ *base.FlagSets, opts *[]c switch len(c.flagCredentialMapping) { case 0: case 1: - if len(c.flagCredentialMapping[0].Keys) == 0 && c.flagCredentialMapping[0].Value == "null" { + if len(c.flagCredentialMapping[0].Keys) == 1 && c.flagCredentialMapping[0].Keys[0] == "null" && c.flagCredentialMapping[0].Value == nil { *opts = append(*opts, credentiallibraries.DefaultCredentialMappingOverrides()) break } @@ -127,16 +127,16 @@ func extraVaultFlagHandlingFuncImpl(c *VaultCommand, _ *base.FlagSets, opts *[]c mappings := make(map[string]any, len(c.flagCredentialMapping)) for _, mapping := range c.flagCredentialMapping { switch { - case len(mapping.Keys) != 1 || mapping.Keys[0] == "" || mapping.Value == "": + case len(mapping.Keys) != 1 || mapping.Keys[0] == "" || mapping.Value == nil || mapping.Value.GetValue() == "": // mapping override does not support key segments (e.g. 'x.y=z') c.UI.Error("Credential mapping override must be in the format 'key=value', 'key=null' to clear field or 'null' to clear all.") return false - case mapping.Value == "null": + case mapping.Value.GetValue() == "null": // user provided 'key=null' indicating the field specific override should // be cleared, set map value to nil mappings[mapping.Keys[0]] = nil default: - mappings[mapping.Keys[0]] = mapping.Value + mappings[mapping.Keys[0]] = mapping.Value.GetValue() } } *opts = append(*opts, credentiallibraries.WithCredentialMappingOverrides(mappings)) diff --git a/internal/cmd/common/flags.go b/internal/cmd/common/flags.go index dc43d3d5ba..7935cbb1cb 100644 --- a/internal/cmd/common/flags.go +++ b/internal/cmd/common/flags.go @@ -164,9 +164,11 @@ func PopulateCombinedSliceFlagValue(input CombinedSliceFlagValuePopulationInput) KvSplit: true, KeyDelimiter: &keyDelimiter, ProtoCompatKey: true, + KeyOnlyAllowed: true, Usage: fmt.Sprintf( "A key=value pair to add to the request's %s map. "+ - "The type is automatically inferred. Use -string-%s, -bool-%s, or -num-%s if the type needs to be overridden. "+ + "This can also be a key value only which will set a JSON null as the value. "+ + "If a value is provided, the type is automatically inferred. Use -string-%s, -bool-%s, or -num-%s if the type needs to be overridden. "+ "Can be specified multiple times. "+ "Supports sourcing values from files via \"file://\" and env vars via \"env://\".", input.FullPopulationInputName, @@ -263,26 +265,36 @@ func HandleAttributeFlags(c *base.Command, suffix, fullField string, sepFields [ // First, perform any needed parsing if we are given the type switch field.Name { case "num-" + suffix: - // JSON treats all numbers equally, however, we will try to be a - // little better so that we don't include decimals if we don't need - // to (and don't have to worry about precision if not necessary) - if strings.Contains(field.Value, ".") { - val, err = strconv.ParseFloat(field.Value, 64) + if field.Value == nil { + return fmt.Errorf("num-%s flag requires a value", suffix) + } + switch { + case strings.Contains(field.Value.GetValue(), "."): + // JSON treats all numbers equally, however, we will try to be a + // little better so that we don't include decimals if we don't need + // to (and don't have to worry about precision if not necessary) + val, err = strconv.ParseFloat(field.Value.GetValue(), 64) if err != nil { return fmt.Errorf("error parsing value %q as a float: %w", field.Value, err) } - } else { - val, err = strconv.ParseInt(field.Value, 10, 64) + default: + val, err = strconv.ParseInt(field.Value.GetValue(), 10, 64) if err != nil { return fmt.Errorf("error parsing value %q as an integer: %w", field.Value, err) } } case "string-" + suffix: - val = strings.Trim(field.Value, `"`) + if field.Value == nil { + return fmt.Errorf("string-%s flag requires a value", suffix) + } + val = strings.Trim(field.Value.GetValue(), `"`) case "bool-" + suffix: - switch field.Value { + if field.Value == nil { + return fmt.Errorf("bool-%s flag requires a value", suffix) + } + switch field.Value.GetValue() { case "true": val = true case "false": @@ -295,44 +307,46 @@ func HandleAttributeFlags(c *base.Command, suffix, fullField string, sepFields [ // In this case, use heuristics to just do the right thing the vast // majority of the time switch { - case field.Value == "null": // Explicit null, we want to set to a null value to clear it + case field.Value == nil: // Key-only, set to null + + case field.Value.GetValue() == "null": // Explicit null, we want to set to a null value to clear it val = nil - case field.Value == "true": // bool true + case field.Value.GetValue() == "true": // bool true val = true - case field.Value == "false": // bool false + case field.Value.GetValue() == "false": // bool false val = false - case strings.HasPrefix(field.Value, `"`): // explicitly quoted string - val = strings.Trim(field.Value, `"`) + case strings.HasPrefix(field.Value.GetValue(), `"`): // explicitly quoted string + val = strings.Trim(field.Value.GetValue(), `"`) - case jsonNumberRegex.MatchString(strings.Trim(field.Value, `"`)): // number + case jsonNumberRegex.MatchString(strings.Trim(field.Value.GetValue(), `"`)): // number // Same logic as above - if strings.Contains(field.Value, ".") { - val, err = strconv.ParseFloat(field.Value, 64) + if strings.Contains(field.Value.GetValue(), ".") { + val, err = strconv.ParseFloat(field.Value.GetValue(), 64) if err != nil { return fmt.Errorf("error parsing value %q as a float: %w", field.Value, err) } } else { - val, err = strconv.ParseInt(field.Value, 10, 64) + val, err = strconv.ParseInt(field.Value.GetValue(), 10, 64) if err != nil { return fmt.Errorf("error parsing value %q as an integer: %w", field.Value, err) } } - case strings.HasPrefix(field.Value, "["): // serialized JSON array + case strings.HasPrefix(field.Value.GetValue(), "["): // serialized JSON array var s []any - u := json.NewDecoder(bytes.NewBufferString(field.Value)) + u := json.NewDecoder(bytes.NewBufferString(field.Value.GetValue())) u.UseNumber() if err := u.Decode(&s); err != nil { return fmt.Errorf("error parsing value %q as a json array: %w", field.Value, err) } val = s - case strings.HasPrefix(field.Value, "{"): // serialized JSON map + case strings.HasPrefix(field.Value.GetValue(), "{"): // serialized JSON map var m map[string]any - u := json.NewDecoder(bytes.NewBufferString(field.Value)) + u := json.NewDecoder(bytes.NewBufferString(field.Value.GetValue())) u.UseNumber() if err := u.Decode(&m); err != nil { return fmt.Errorf("error parsing value %q as a json map: %w", field.Value, err) @@ -343,6 +357,7 @@ func HandleAttributeFlags(c *base.Command, suffix, fullField string, sepFields [ // Default is to treat as a string value val = field.Value } + default: return fmt.Errorf("unknown flag %q", field.Name) } @@ -374,7 +389,7 @@ func HandleAttributeFlags(c *base.Command, suffix, fullField string, sepFields [ default: // It's not a slice, so create a new slice with the - // exisitng and new values + // existing and new values currMap[segment] = []any{t, val} } diff --git a/internal/cmd/common/flags_test.go b/internal/cmd/common/flags_test.go index 705f523019..4ad357e985 100644 --- a/internal/cmd/common/flags_test.go +++ b/internal/cmd/common/flags_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/cmd/base" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" ) // TestPopulateAttrFlags tests common patterns we'll actually be using. Note @@ -29,12 +30,12 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "string-attr", Keys: []string{"foo"}, - Value: "bar", + Value: wrapperspb.String("bar"), }, { Name: "string-attr", Keys: []string{"bar"}, - Value: `"baz"`, + Value: wrapperspb.String(`"baz"`), }, }, }, @@ -45,12 +46,12 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "num-attr", Keys: []string{"foo"}, - Value: "-1.2", + Value: wrapperspb.String("-1.2"), }, { Name: "num-attr", Keys: []string{"bar"}, - Value: "5", + Value: wrapperspb.String("5"), }, }, }, @@ -61,15 +62,40 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "bool-attr", Keys: []string{"foo"}, - Value: "true", + Value: wrapperspb.String("true"), }, { Name: "bool-attr", Keys: []string{"bar"}, - Value: "false", + Value: wrapperspb.String("false"), }, }, }, + { + name: "key-only", + args: []string{"-attr", "foo"}, + expected: []base.CombinedSliceFlagValue{ + { + Name: "attr", + Keys: []string{"foo"}, + }, + }, + }, + { + name: "bad-key-only-bool", + args: []string{"-bool-attr", "foo"}, + expectedErr: `invalid value "foo" for flag -bool-attr: key-only value provided but not supported for this flag`, + }, + { + name: "bad-key-only-num", + args: []string{"-num-attr", "foo"}, + expectedErr: `invalid value "foo" for flag -num-attr: key-only value provided but not supported for this flag`, + }, + { + name: "bad-key-only-string", + args: []string{"-string-attr", "foo"}, + expectedErr: `invalid value "foo" for flag -string-attr: key-only value provided but not supported for this flag`, + }, { name: "mixed", args: []string{"-num-attr", "foo=9820", "-string-attr", "bar=9820", "-attr", "baz=9820"}, @@ -77,17 +103,17 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "num-attr", Keys: []string{"foo"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, { Name: "string-attr", Keys: []string{"bar"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, { Name: "attr", Keys: []string{"baz"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, }, }, @@ -98,17 +124,17 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "num-attr", Keys: []string{"foo", "bar", "baz"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, { Name: "string-attr", Keys: []string{"bar", "baz", "foo"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, { Name: "attr", Keys: []string{"baz", "foo", "bar"}, - Value: "9820", + Value: wrapperspb.String("9820"), }, }, }, @@ -129,7 +155,7 @@ func TestPopulateAttrFlags(t *testing.T) { { Name: "attr", Keys: []string{"filter"}, - Value: "tagName eq 'application:south-seas'", + Value: wrapperspb.String("tagName eq 'application:south-seas'"), }, }, }, @@ -195,12 +221,12 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "string-%s", Keys: []string{"foo"}, - Value: "bar", + Value: wrapperspb.String("bar"), }, { Name: "string-%s", Keys: []string{"bar"}, - Value: `"baz"`, + Value: wrapperspb.String(`"baz"`), }, }, expectedMap: map[string]any{ @@ -214,12 +240,12 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "num-%s", Keys: []string{"foo"}, - Value: "-1.2", + Value: wrapperspb.String("-1.2"), }, { Name: "num-%s", Keys: []string{"bar"}, - Value: "5", + Value: wrapperspb.String("5"), }, }, expectedMap: map[string]any{ @@ -233,7 +259,7 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "num-%s", Keys: []string{"foo"}, - Value: "-15d.2", + Value: wrapperspb.String("-15d.2"), }, }, expectedErr: "as a float", @@ -244,7 +270,7 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "num-%s", Keys: []string{"foo"}, - Value: "-15d3", + Value: wrapperspb.String("-15d3"), }, }, expectedErr: "as an int", @@ -255,12 +281,12 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "bool-%s", Keys: []string{"foo"}, - Value: "true", + Value: wrapperspb.String("true"), }, { Name: "bool-%s", Keys: []string{"bar"}, - Value: "false", + Value: wrapperspb.String("false"), }, }, expectedMap: map[string]any{ @@ -274,64 +300,106 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "bool-%s", Keys: []string{"foo"}, - Value: "t", + Value: wrapperspb.String("t"), }, }, expectedErr: "as a bool", }, + { + name: "key-only-bare", + args: []base.CombinedSliceFlagValue{ + { + Name: "%s", + Keys: []string{"foo"}, + }, + }, + expectedMap: map[string]any{ + "foo": nil, + }, + }, + { + name: "bad-key-only-bool", + args: []base.CombinedSliceFlagValue{ + { + Name: "bool-%s", + Keys: []string{"foo"}, + }, + }, + expectedErr: `does not support key-only values`, + }, + { + name: "bad-key-only-num", + args: []base.CombinedSliceFlagValue{ + { + Name: "num-%s", + Keys: []string{"foo"}, + }, + }, + expectedErr: `does not support key-only values`, + }, + { + name: "bad-key-only-string", + args: []base.CombinedSliceFlagValue{ + { + Name: "string-%s", + Keys: []string{"foo"}, + }, + }, + expectedErr: `does not support key-only values`, + }, { name: "attr-only", args: []base.CombinedSliceFlagValue{ { Name: "%s", Keys: []string{"b1"}, - Value: "true", + Value: wrapperspb.String("true"), }, { Name: "%s", Keys: []string{"b2"}, - Value: "false", + Value: wrapperspb.String("false"), }, { Name: "%s", Keys: []string{"s1"}, - Value: "scoopde", + Value: wrapperspb.String("scoopde"), }, { Name: "%s", Keys: []string{"s2"}, - Value: `"woop"`, + Value: wrapperspb.String(`"woop"`), }, { Name: "%s", Keys: []string{"n1"}, - Value: "-1.2", + Value: wrapperspb.String("-1.2"), }, { Name: "%s", Keys: []string{"n2"}, - Value: "5", + Value: wrapperspb.String("5"), }, { Name: "%s", Keys: []string{"a"}, - Value: `["foo", 1.5, true, ["bar"], {"hip": "hop"}]`, + Value: wrapperspb.String(`["foo", 1.5, true, ["bar"], {"hip": "hop"}]`), }, { Name: "%s", Keys: []string{"nil"}, - Value: "null", + Value: wrapperspb.String("null"), }, { Name: "%s", Keys: []string{"m"}, - Value: `{"b": true, "n": 6, "s": "scoopde", "a": ["bar"], "m": {"hip": "hop"}}`, + Value: wrapperspb.String(`{"b": true, "n": 6, "s": "scoopde", "a": ["bar"], "m": {"hip": "hop"}}`), }, }, expectedMap: map[string]any{ "b1": true, "b2": false, - "s1": "scoopde", + "s1": wrapperspb.String("scoopde"), "s2": "woop", "n1": float64(-1.2), "n2": int64(5), @@ -358,43 +426,43 @@ func TestHandleAttributeFlags(t *testing.T) { { Name: "%s", Keys: []string{"bools"}, - Value: "true", + Value: wrapperspb.String("true"), }, { Name: "%s", Keys: []string{"bools"}, - Value: "false", + Value: wrapperspb.String("false"), }, { Name: "%s", Keys: []string{"strings", "s1"}, - Value: "scoopde", + Value: wrapperspb.String("scoopde"), }, { Name: "%s", - Keys: []string{"strings", "s2"}, - Value: `"woop"`, + Keys: []string{"strings", "s2"}, // Overwritten below + Value: wrapperspb.String(`"woop"`), }, { Name: "%s", Keys: []string{"numbers", "reps"}, - Value: "-1.2", + Value: wrapperspb.String("-1.2"), }, { Name: "%s", Keys: []string{"numbers", "reps"}, - Value: "5", + Value: wrapperspb.String("5"), }, { Name: "%s", Keys: []string{"strings", "s2"}, // This will overwrite above! - Value: "null", + Value: wrapperspb.String("null"), }, }, expectedMap: map[string]any{ "bools": []any{true, false}, "strings": map[string]any{ - "s1": "scoopde", + "s1": wrapperspb.String("scoopde"), "s2": nil, }, "numbers": map[string]any{