diff --git a/internal/cmd/base/flags.go b/internal/cmd/base/flags.go index b6080b0756..0f69ebd871 100644 --- a/internal/cmd/base/flags.go +++ b/internal/cmd/base/flags.go @@ -226,6 +226,70 @@ func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i.tar func (i *uint64Value) Example() string { return "uint" } func (i *uint64Value) Hidden() bool { return i.hidden } +// -- Uint16Var and uint16Value +type Uint16Var struct { + Name string + Aliases []string + Usage string + Default uint16 + Hidden bool + EnvVar string + Target *uint16 + Completion complete.Predictor +} + +func (f *FlagSet) Uint16Var(i *Uint16Var) { + initial := i.Default + if v, exist := os.LookupEnv(i.EnvVar); exist { + if i, err := strconv.ParseUint(v, 0, 16); err == nil { + initial = uint16(i) + } + } + + def := "" + if i.Default != 0 { + strconv.FormatUint(uint64(i.Default), 10) + } + + f.VarFlag(&VarFlag{ + Name: i.Name, + Aliases: i.Aliases, + Usage: i.Usage, + Default: def, + EnvVar: i.EnvVar, + Value: newUint16Value(initial, i.Target, i.Hidden), + Completion: i.Completion, + }) +} + +type uint16Value struct { + hidden bool + target *uint16 +} + +func newUint16Value(def uint16, target *uint16, hidden bool) *uint16Value { + *target = def + return &uint16Value{ + hidden: hidden, + target: target, + } +} + +func (i *uint16Value) Set(s string) error { + v, err := strconv.ParseUint(s, 0, 16) + if err != nil { + return err + } + + *i.target = uint16(v) + return nil +} + +func (i *uint16Value) Get() any { return uint64(*i.target) } +func (i *uint16Value) String() string { return strconv.FormatUint(uint64(*i.target), 10) } +func (i *uint16Value) Example() string { return "uint" } +func (i *uint16Value) Hidden() bool { return i.hidden } + // -- StringVar and stringValue type StringVar struct { Name string diff --git a/internal/cmd/base/flags_test.go b/internal/cmd/base/flags_test.go index ae646082ee..91bf0036f9 100644 --- a/internal/cmd/base/flags_test.go +++ b/internal/cmd/base/flags_test.go @@ -4,8 +4,10 @@ package base import ( + "os" "testing" + "github.com/mitchellh/cli" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -152,3 +154,90 @@ func TestFlagSet_StringSliceMapVar_NullCheck(t *testing.T) { }) } } + +func TestUint16Var(t *testing.T) { + t.Parallel() + + target := uint16(0) + + flagSets := NewFlagSets(cli.NewMockUi()) + f := flagSets.NewFlagSet("testset") + f.Uint16Var(&Uint16Var{ + Name: "test_name", + Aliases: []string{"test_alias1"}, + Usage: "test_usage", + Default: 1, + Hidden: false, + Target: &target, + }) + require.Equal(t, uint16(1), target) // Should immediately default. + + // Value that overflows uint16 should error. + err := flagSets.Parse([]string{"-test_name", "66000"}) + require.EqualError(t, err, "invalid value \"66000\" for flag -test_name: strconv.ParseUint: parsing \"66000\": value out of range") + require.Equal(t, uint16(1), target) + + // Value that overflows uint16 (via alias) should error. + err = flagSets.Parse([]string{"-test_alias1", "66000"}) + require.EqualError(t, err, "invalid value \"66000\" for flag -test_alias1: strconv.ParseUint: parsing \"66000\": value out of range") + require.Equal(t, uint16(1), target) + + // Negative value should error. + err = flagSets.Parse([]string{"-test_name", "-1"}) + require.EqualError(t, err, "invalid value \"-1\" for flag -test_name: strconv.ParseUint: parsing \"-1\": invalid syntax") + require.Equal(t, uint16(1), target) + + // Negative value (via alias) should error. + err = flagSets.Parse([]string{"-test_alias1", "-1"}) + require.EqualError(t, err, "invalid value \"-1\" for flag -test_alias1: strconv.ParseUint: parsing \"-1\": invalid syntax") + require.Equal(t, uint16(1), target) + + // Valid value should be put into target. + err = flagSets.Parse([]string{"-test_name", "123"}) + require.NoError(t, err) + require.Equal(t, uint16(123), target) + + // Valid value (using alias) should be put into target. + err = flagSets.Parse([]string{"-test_alias1", "456"}) + require.NoError(t, err) + require.Equal(t, uint16(456), target) + + // Env var tests. + envTarget := uint16(0) + envVarName := "test_uint16_env_var" + + envFlagSets := NewFlagSets(cli.NewMockUi()) + ef := envFlagSets.NewFlagSet("env_testset") + + require.NoError(t, os.Setenv(envVarName, "66000")) + ef.Uint16Var(&Uint16Var{ + Name: "test_env_name1", + Default: 1, + EnvVar: envVarName, + Target: &envTarget, + }) + require.Equal(t, uint16(1), envTarget) // Should be set to default because env value parse will have failed. + require.NoError(t, os.Unsetenv(envVarName)) + envTarget = uint16(0) + + require.NoError(t, os.Setenv(envVarName, "-1")) + ef.Uint16Var(&Uint16Var{ + Name: "test_env_name2", + Default: 1, + EnvVar: envVarName, + Target: &envTarget, + }) + require.Equal(t, uint16(1), envTarget) // Should be set to default because env value parse will have failed. + require.NoError(t, os.Unsetenv(envVarName)) + envTarget = uint16(0) + + require.NoError(t, os.Setenv(envVarName, "123")) + ef.Uint16Var(&Uint16Var{ + Name: "test_env_name3", + Default: 1, + EnvVar: envVarName, + Target: &envTarget, + }) + require.Equal(t, uint16(123), envTarget) // Should be set to what was set in env. + require.NoError(t, os.Unsetenv(envVarName)) +}