diff --git a/internal/cmd/base/flags.go b/internal/cmd/base/flags.go index 1d54d723d7..34dfbc87c8 100644 --- a/internal/cmd/base/flags.go +++ b/internal/cmd/base/flags.go @@ -548,7 +548,11 @@ func appendDurationSuffix(s string) string { return s + "s" } -// -- StringSliceVar and stringSliceValue +// StringSliceVar reads in parameters from the same flag into a string array. +// Setting NullCheck enables the following behavior: the function will be run +// whenever a "null" is seen in the input to determine whether or not it is +// allowed (and erroring if not); and even if allowed generally, the flag will +// ensure that "null" is the only value passed. type StringSliceVar struct { Name string Aliases []string @@ -557,6 +561,7 @@ type StringSliceVar struct { Hidden bool EnvVar string Target *[]string + NullCheck func() bool Completion complete.Predictor } @@ -581,25 +586,53 @@ func (f *FlagSet) StringSliceVar(i *StringSliceVar) { Usage: i.Usage, Default: def, EnvVar: i.EnvVar, - Value: newStringSliceValue(initial, i.Target, i.Hidden), + Value: newStringSliceValue(initial, i.Target, i.Hidden, i.NullCheck), Completion: i.Completion, }) } type stringSliceValue struct { - hidden bool - target *[]string + hidden bool + nullCheck func() bool + target *[]string } -func newStringSliceValue(def []string, target *[]string, hidden bool) *stringSliceValue { +func newStringSliceValue(def []string, target *[]string, hidden bool, nullCheck func() bool) *stringSliceValue { *target = def return &stringSliceValue{ - hidden: hidden, - target: target, + hidden: hidden, + nullCheck: nullCheck, + target: target, } } func (s *stringSliceValue) Set(val string) error { + trimmedVal := strings.TrimSpace(val) + // Don't enable this behavior if nullCheck is not defined + if s.nullCheck != nil { + // If we got null in, go through some checks + if trimmedVal == "null" { + // Ensure null is allowed here + if !s.nullCheck() { + return fmt.Errorf(`"null" is not an allowed value`) + } + // If we have at least one value already and it's not "null" then + // error; we don't check if _all_ are "null" since presumably this + // function prevents there being more values than just "null" if + // that's specified + if len(*s.target) > 0 && (*s.target)[0] != "null" { + return fmt.Errorf(`"null" cannot be combined with other values`) + } + // Set the target to only contain "null" and return + *s.target = []string{"null"} + return nil + } else if len(*s.target) == 1 && (*s.target)[0] == "null" { + // Something came in that isn't "null" but we already have "null" + return fmt.Errorf(`"null" cannot be combined with other values`) + } + } + // Append in all other cases, or if we didn't error above or return just + // "null" *s.target = append(*s.target, strings.TrimSpace(val)) return nil } diff --git a/internal/cmd/common/flags_test.go b/internal/cmd/common/flags_test.go index b1aa85d8b3..80d2dc91b8 100644 --- a/internal/cmd/common/flags_test.go +++ b/internal/cmd/common/flags_test.go @@ -3,6 +3,7 @@ package common import ( "encoding/json" "fmt" + "strings" "testing" "github.com/hashicorp/boundary/internal/cmd/base" @@ -422,3 +423,103 @@ func TestHandleAttributeFlags(t *testing.T) { } } } + +func TestNullableStringSlice(t *testing.T) { + makeStringSlicePointer := func(in ...string) *[]string { + return &in + } + tests := []struct { + name string + cmd string + args []string + expected base.StringSliceVar + expectedErr string + }{ + { + name: "not-set-no-null", + cmd: "add-values", + args: []string{"-val", "foobar", "-val", "barfoo", "-val", "boobaz", "-val", "bazboo"}, + expected: base.StringSliceVar{ + Target: makeStringSlicePointer("foobar", "barfoo", "boobaz", "bazboo"), + }, + }, + { + name: "not-set-null", + cmd: "add-values", + args: []string{"-val", "foobar", "-val", "null", "-val", "boobaz", "-val", "bazboo"}, + expectedErr: `"null" is not an allowed value`, + }, + { + name: "set-no-null", + cmd: "set-values", + args: []string{"-val", "foobar", "-val", "barfoo", "-val", "boobaz", "-val", "bazboo"}, + expected: base.StringSliceVar{ + Target: makeStringSlicePointer("foobar", "barfoo", "boobaz", "bazboo"), + }, + }, + { + name: "set-only-null", + cmd: "set-values", + args: []string{"-val", "null"}, + expected: base.StringSliceVar{ + Target: makeStringSlicePointer("null"), + }, + }, + { + name: "set-null-and-others-beginning", + cmd: "set-values", + args: []string{"-val", "null", "-val", "barfoo", "-val", "boobaz", "-val", "bazboo"}, + expectedErr: `"null" cannot be combined with other values`, + }, + { + name: "set-null-and-others-middle", + cmd: "set-values", + args: []string{"-val", "foobar", "-val", "null", "-val", "boobaz", "-val", "bazboo"}, + expectedErr: `"null" cannot be combined with other values`, + }, + { + name: "set-null-and-others-end", + cmd: "set-values", + args: []string{"-val", "foobar", "-val", "barfoo", "-val", "boobaz", "-val", "null"}, + expectedErr: `"null" cannot be combined with other values`, + }, + { + name: "set-null-multiple", + cmd: "set-values", + args: []string{"-val", "null", "-val", "null"}, + expected: base.StringSliceVar{ + Target: makeStringSlicePointer("null"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // Note: we do the setup on each run to make sure we aren't carrying + // state over; just like in the real CLI where each run would have + // pristine state. + c := new(base.Command) + flagSet := c.FlagSet(base.FlagSetNone) + f := flagSet.NewFlagSet("Stringsssss") + var target []string + ssVar := &base.StringSliceVar{ + Name: "val", + Target: &target, + NullCheck: func() bool { + return strings.HasPrefix(tt.cmd, "set-") + }, + } + f.StringSliceVar(ssVar) + + err := flagSet.Parse(tt.args) + if tt.expectedErr != "" { + require.Error(err) + assert.Contains(err.Error(), tt.expectedErr) + return + } + require.NoError(err) + assert.Equal(*tt.expected.Target, *ssVar.Target) + }) + } +}