From 6afd2c96a5df301b91e824bb6ad3a4d69b32434f Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Mon, 17 Jul 2023 11:55:46 -0400 Subject: [PATCH] feat: Add custom summary error type (#3459) Add custom error type to to decode and encode `Errors` field Added some extra checks to handle existing BSR summary types with `Error` fields --- internal/bsr/types.go | 51 +++++++++++++++--- internal/bsr/types_test.go | 106 +++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 internal/bsr/types_test.go diff --git a/internal/bsr/types.go b/internal/bsr/types.go index 4cada819d7..87fe2a58b1 100644 --- a/internal/bsr/types.go +++ b/internal/bsr/types.go @@ -5,10 +5,49 @@ package bsr import ( "context" + "encoding/json" "errors" "time" ) +type SummaryError struct { + Message string `json:"message"` +} + +func (e *SummaryError) Error() string { + return e.Message +} + +func (e *SummaryError) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Message) +} + +func (e *SummaryError) UnmarshalJSON(data []byte) error { + var rawMessage json.RawMessage + var message string + + err := json.Unmarshal(data, &rawMessage) + if err != nil { + return err + } + + switch { + case len(rawMessage) > 0 && string(rawMessage) == `"{}"`: + e.Message = "" + case len(rawMessage) > 0 && string(rawMessage) == `"null"`: + e.Message = "" + case len(rawMessage) <= 0: + e.Message = "" + default: + if err := json.Unmarshal(data, &message); err != nil { + return err + } + e.Message = message + } + + return nil +} + // BaseSummary contains the common fields of all summary types. type BaseSummary struct { Id string @@ -49,7 +88,7 @@ type BaseSessionSummary struct { ConnectionCount uint64 StartTime time.Time EndTime time.Time - Errors string + Errors SummaryError } // SessionSummary contains statistics for a session container @@ -89,12 +128,12 @@ func (b *BaseSessionSummary) GetConnectionCount() uint64 { // GetErrors returns errors. func (b *BaseSessionSummary) GetErrors() error { - return errors.New(b.Errors) + return errors.New(b.Errors.Message) } // SetErrors sets errors. func (b *BaseSessionSummary) SetErrors(e error) { - b.Errors = e.Error() + b.Errors = SummaryError{Message: e.Error()} } // BaseChannelSummary encapsulates data for a channel, including its id, channel type, @@ -170,7 +209,7 @@ type BaseConnectionSummary struct { EndTime time.Time BytesUp uint64 BytesDown uint64 - Errors string + Errors SummaryError } // ConnectionSummary contains statistics for a connection container @@ -224,10 +263,10 @@ func (b *BaseConnectionSummary) GetBytesDown() uint64 { // GetErrors returns errors. func (b *BaseConnectionSummary) GetErrors() error { - return errors.New(b.Errors) + return &b.Errors } // SetErrors sets errors. func (b *BaseConnectionSummary) SetErrors(e error) { - b.Errors = e.Error() + b.Errors = SummaryError{Message: e.Error()} } diff --git a/internal/bsr/types_test.go b/internal/bsr/types_test.go new file mode 100644 index 0000000000..eae6403c89 --- /dev/null +++ b/internal/bsr/types_test.go @@ -0,0 +1,106 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bsr_test + +import ( + "encoding/json" + "testing" + + "github.com/hashicorp/boundary/internal/bsr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSummaryError_MarshalJSON(t *testing.T) { + cases := []struct { + name string + in string + want string + }{ + { + name: "error string", + in: "error", + want: "\"error\"", + }, + { + name: "empty string", + in: "", + want: "\"\"", + }, + { + name: "empty object", + in: `{}`, + want: "\"{}\"", + }, + { + name: "empty object", + in: `null`, + want: "\"null\"", + }, + { + name: "key value object", + in: `{"message": "failed to load"}`, + want: "\"{\\\"message\\\": \\\"failed to load\\\"}\"", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + summaryErr := bsr.SummaryError{Message: tc.in} + + got, err := summaryErr.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, tc.want, string(got)) + }) + } +} + +func TestSummaryError_UnmarshalJSON(t *testing.T) { + cases := []struct { + name string + in any + want string + }{ + { + name: "error string", + in: "error", + want: "error", + }, + { + name: "empty string", + in: "", + want: "", + }, + { + name: "empty object", + in: `{}`, + want: "", + }, + { + name: "empty object", + in: `null`, + want: "", + }, + { + name: "key value object", + in: `{"message": "failed to load"}`, + want: "{\"message\": \"failed to load\"}", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var summaryErr bsr.SummaryError + + input, err := json.Marshal(tc.in) + require.NoError(t, err) + + err = summaryErr.UnmarshalJSON(input) + require.NoError(t, err) + + assert.Equal(t, tc.want, summaryErr.Message) + }) + } +}