diff --git a/internal/bsr/types.go b/internal/bsr/types.go index 87fe2a58b1..c439f075c7 100644 --- a/internal/bsr/types.go +++ b/internal/bsr/types.go @@ -128,6 +128,9 @@ func (b *BaseSessionSummary) GetConnectionCount() uint64 { // GetErrors returns errors. func (b *BaseSessionSummary) GetErrors() error { + if len(b.Errors.Message) == 0 { + return nil + } return errors.New(b.Errors.Message) } @@ -263,6 +266,9 @@ func (b *BaseConnectionSummary) GetBytesDown() uint64 { // GetErrors returns errors. func (b *BaseConnectionSummary) GetErrors() error { + if len(b.Errors.Message) == 0 { + return nil + } return &b.Errors } diff --git a/internal/bsr/types_test.go b/internal/bsr/types_test.go index eae6403c89..b67716f98a 100644 --- a/internal/bsr/types_test.go +++ b/internal/bsr/types_test.go @@ -5,6 +5,7 @@ package bsr_test import ( "encoding/json" + "errors" "testing" "github.com/hashicorp/boundary/internal/bsr" @@ -104,3 +105,31 @@ func TestSummaryError_UnmarshalJSON(t *testing.T) { }) } } + +func TestBaseSessionSummary_GetErrors(t *testing.T) { + cases := []struct { + name string + in error + want error + }{ + { + name: "error string", + in: errors.New("error"), + want: errors.New("error"), + }, + { + name: "empty string should return nil error", + in: errors.New(""), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + summary := bsr.BaseSessionSummary{} + summary.SetErrors(tc.in) + + got := summary.GetErrors() + assert.Equal(t, tc.want, got) + }) + } +}