diff --git a/api/client.go b/api/client.go index cfb20229e8..3a71fd57ee 100644 --- a/api/client.go +++ b/api/client.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "math" "net" "net/http" "net/url" @@ -308,6 +309,12 @@ func (c *Config) ReadEnvironment() error { if err != nil { return err } + // maxRetries is a 32-bit unsigned integer stored inside an uint64. + // c.MaxRetries is a signed integer that is at least 32 bits in size. + // Check bounds against lowest denominator before casting. + if maxRetries > math.MaxInt32 { + return fmt.Errorf("max retries must be less than or equal to %d", math.MaxInt32) + } c.MaxRetries = int(maxRetries) } diff --git a/api/client_test.go b/api/client_test.go index a886bf3224..2e8bea6c7f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -4,9 +4,13 @@ package api import ( + "math" + "os" + "strconv" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConfigSetAddress(t *testing.T) { @@ -79,3 +83,71 @@ func TestConfigSetAddress(t *testing.T) { }) } } + +func TestReadEnvironmentMaxRetries(t *testing.T) { + tests := []struct { + name string + inp string + expMaxRetries int + expErrContains string + }{ + { + name: "invalidNaN", + inp: "bad", + expErrContains: "strconv.ParseUint: parsing \"bad\": invalid syntax", + }, + { + name: "invalidNegativeNumber", + inp: "-1", + expErrContains: "strconv.ParseUint: parsing \"-1\": invalid syntax", + }, + { + name: "invalidGreaterThanUint32", + inp: strconv.Itoa(math.MaxUint32 + 10), + expErrContains: "strconv.ParseUint: parsing \"4294967305\": value out of range", + }, + { + name: "invalidGreaterThanInt32", + inp: strconv.Itoa(math.MaxInt32 + 10), + expErrContains: "max retries must be less than or equal to 2147483647", + }, + { + name: "success1", + inp: "0", + expMaxRetries: 0, + }, + { + name: "success2", + inp: "10000", + expMaxRetries: 10000, + }, + { + name: "successMaxInt32", + inp: strconv.Itoa(math.MaxInt32), + expMaxRetries: math.MaxInt32, + }, + { + name: "successNothing", + inp: "", + expMaxRetries: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(EnvBoundaryMaxRetries, tt.inp) + t.Cleanup(func() { os.Unsetenv(EnvBoundaryMaxRetries) }) + + var c Config + err := c.ReadEnvironment() + if tt.expErrContains != "" { + require.ErrorContains(t, err, tt.expErrContains) + require.Equal(t, 0, c.MaxRetries) + return + } + + require.NoError(t, err) + require.Equal(t, tt.expMaxRetries, c.MaxRetries) + }) + } +}