diff --git a/internal/bsr/bsr.go b/internal/bsr/bsr.go index 79cfc2a162..546fa8a668 100644 --- a/internal/bsr/bsr.go +++ b/internal/bsr/bsr.go @@ -41,6 +41,7 @@ type Session struct { Meta *SessionRecordingMeta SessionMeta *SessionMeta + Summary SessionSummary } // NewSession creates a Session container for a given session id. @@ -81,7 +82,7 @@ func NewSession(ctx context.Context, meta *SessionRecordingMeta, sessionMeta *Se return nil, err } - nc, err := newContainer(ctx, sessionContainer, c, keys) + nc, err := newContainer(ctx, SessionContainer, c, keys) if err != nil { return nil, err } @@ -191,7 +192,7 @@ func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, k keyPopFn := func(c *container) (*kms.Keys, error) { return c.loadKeys(ctx, keyUnwrapFn) } - cc, err := openContainer(ctx, sessionContainer, c, keyPopFn) + cc, err := openContainer(ctx, SessionContainer, c, keyPopFn) if err != nil { return nil, err } @@ -215,10 +216,22 @@ func OpenSession(ctx context.Context, sessionRecordingId string, f storage.FS, k return nil, err } + af, ok := summaryAllocFuncs.get(meta.Protocol, SessionContainer) + if !ok { + return nil, fmt.Errorf("%s: failed to get summary type", op) + } + + summary := af(ctx) + if err := cc.decodeJsonFile(ctx, fmt.Sprintf(summaryFileNameTemplate, SessionContainer), summary); err != nil { + return nil, err + } + sessionSummary := summary.(SessionSummary) + session := &Session{ container: cc, Meta: meta, SessionMeta: sessionMeta, + Summary: sessionSummary, } return session, nil @@ -235,7 +248,9 @@ type Connection struct { *container multiplexed bool - Meta *ConnectionRecordingMeta + Meta *ConnectionRecordingMeta + session *Session + Summary ConnectionSummary } // NewConnection creates a Connection container for a given connection id. @@ -258,7 +273,7 @@ func (s *Session) NewConnection(ctx context.Context, meta *ConnectionRecordingMe return nil, err } - nc, err := newContainer(ctx, connectionContainer, sc, s.keys) + nc, err := newContainer(ctx, ConnectionContainer, sc, s.keys) if err != nil { return nil, err } @@ -269,6 +284,7 @@ func (s *Session) NewConnection(ctx context.Context, meta *ConnectionRecordingMe container: nc, multiplexed: s.multiplexed, Meta: meta, + session: s, }, nil } @@ -295,7 +311,7 @@ func (s *Session) OpenConnection(ctx context.Context, connId string) (*Connectio keyPopFn := func(c *container) (*kms.Keys, error) { return s.keys, nil } - cc, err := openContainer(ctx, connectionContainer, c, keyPopFn) + cc, err := openContainer(ctx, ConnectionContainer, c, keyPopFn) if err != nil { return nil, err } @@ -314,9 +330,22 @@ func (s *Session) OpenConnection(ctx context.Context, connId string) (*Connectio return nil, err } + af, ok := summaryAllocFuncs.get(s.Meta.Protocol, ConnectionContainer) + if !ok { + return nil, fmt.Errorf("%s: failed to get summary type", op) + } + + summary := af(ctx) + if err := cc.decodeJsonFile(ctx, fmt.Sprintf(summaryFileNameTemplate, ConnectionContainer), summary); err != nil { + return nil, err + } + connectionSummary := summary.(ConnectionSummary) + connection := &Connection{ container: cc, Meta: sm, + session: s, + Summary: connectionSummary, } return connection, nil @@ -345,7 +374,7 @@ func (c *Connection) NewChannel(ctx context.Context, meta *ChannelRecordingMeta) if _, err := c.WriteMeta(ctx, "channel", name); err != nil { return nil, err } - nc, err := newContainer(ctx, channelContainer, sc, c.keys) + nc, err := newContainer(ctx, ChannelContainer, sc, c.keys) if err != nil { return nil, err } @@ -383,7 +412,7 @@ func (c *Connection) OpenChannel(ctx context.Context, chanId string) (*Channel, keyPopFn := func(cn *container) (*kms.Keys, error) { return c.keys, nil } - cc, err := openContainer(ctx, channelContainer, con, keyPopFn) + cc, err := openContainer(ctx, ChannelContainer, con, keyPopFn) if err != nil { return nil, err } @@ -402,9 +431,21 @@ func (c *Connection) OpenChannel(ctx context.Context, chanId string) (*Channel, return nil, err } + af, ok := summaryAllocFuncs.get(c.session.Meta.Protocol, ChannelContainer) + if !ok { + return nil, fmt.Errorf("%s: failed to get summary type", op) + } + + summary := af(ctx) + if err := cc.decodeJsonFile(ctx, fmt.Sprintf(summaryFileNameTemplate, ChannelContainer), summary); err != nil { + return nil, err + } + channelSummary := summary.(ChannelSummary) + channel := &Channel{ container: cc, Meta: sm, + Summary: channelSummary, } return channel, nil @@ -464,7 +505,8 @@ func (c *Connection) Close(ctx context.Context) error { type Channel struct { *container - Meta *ChannelRecordingMeta + Meta *ChannelRecordingMeta + Summary ChannelSummary } // Close closes the Channel container. diff --git a/internal/bsr/bsr_open_test.go b/internal/bsr/bsr_open_test.go index b54d93bd13..e3ca0dc7a1 100644 --- a/internal/bsr/bsr_open_test.go +++ b/internal/bsr/bsr_open_test.go @@ -5,6 +5,7 @@ package bsr import ( "context" + "fmt" "testing" "github.com/hashicorp/boundary/internal/bsr/internal/fstest" @@ -79,14 +80,31 @@ func TestPopulateMeta(t *testing.T) { func TestOpenBSRMethods(t *testing.T) { ctx := context.Background() + protocol := Protocol("TEST") + keys, err := kms.CreateKeys(ctx, kms.TestWrapper(t), "session") require.NoError(t, err) + err = RegisterSummaryAllocFunc(protocol, ChannelContainer, func(ctx context.Context) Summary { + return &BaseChannelSummary{Id: "TEST_CHANNEL_ID", ConnectionRecordingId: "TEST_CONNECTION_RECORDING_ID"} + }) + require.NoError(t, err) + + err = RegisterSummaryAllocFunc(protocol, SessionContainer, func(ctx context.Context) Summary { + return &BaseSessionSummary{Id: "TEST_SESSION_ID", ConnectionCount: 1} + }) + require.NoError(t, err) + + err = RegisterSummaryAllocFunc(protocol, ConnectionContainer, func(ctx context.Context) Summary { + return &BaseConnectionSummary{Id: "TEST_CONNECTION_ID", ChannelCount: 1} + }) + require.NoError(t, err) + f := &fstest.MemFS{} sessionId := "s_01234567890" srm := &SessionRecordingMeta{ Id: "sr_012344567890", - Protocol: Protocol("TEST"), + Protocol: protocol, } sessionMeta := TestSessionMeta(sessionId) @@ -94,12 +112,22 @@ func TestOpenBSRMethods(t *testing.T) { require.NoError(t, err) require.NotNil(t, sesh) + sesh.EncodeSummary(ctx, &BaseChannelSummary{ + Id: "TEST_CHANNEL_ID", + ConnectionRecordingId: "TEST_CONNECTION_RECORDING_ID", + }) + connectionId := "connection" connMeta := &ConnectionRecordingMeta{Id: connectionId} conn, err := sesh.NewConnection(ctx, connMeta) require.NoError(t, err) require.NotNil(t, conn) + conn.EncodeSummary(ctx, &BaseConnectionSummary{ + Id: "TEST_CONNECTION_ID", + ChannelCount: 1, + }) + channelId := "channel" chanMeta := &ChannelRecordingMeta{ Id: channelId, @@ -109,6 +137,11 @@ func TestOpenBSRMethods(t *testing.T) { require.NoError(t, err) require.NotNil(t, ch) + ch.EncodeSummary(ctx, &BaseSessionSummary{ + Id: "TEST_SESSION_ID", + ConnectionCount: 1, + }) + ch.Close(ctx) conn.Close(ctx) sesh.Close(ctx) @@ -293,3 +326,146 @@ func TestOpenChannel(t *testing.T) { }) } } + +func TestOpenBSRMethods_WithoutSummaryAllocFunc(t *testing.T) { + ctx := context.Background() + f := &fstest.MemFS{} + + cases := []struct { + name string + protocol Protocol + sId int + sessionAllocFunc SessionSummary + connectionAllocFunc ConnectionSummary + channelAllocFunc ChannelSummary + expectedError string + wantSessionErr bool + wantConnErr bool + wantChanErr bool + }{ + { + name: "without-session-allocFunc", + protocol: Protocol("TEST_BSR_OPEN_SESSION_PROTOCOL"), + sId: 12345, + sessionAllocFunc: nil, + connectionAllocFunc: &BaseConnectionSummary{Id: "TEST_CONNECTION_ID", ChannelCount: 1}, + channelAllocFunc: &BaseChannelSummary{Id: "TEST_CHANNEL_ID", ConnectionRecordingId: "TEST_CONNECTION_RECORDING_ID"}, + expectedError: "bsr.OpenSession: failed to get summary type", + wantSessionErr: true, + }, + { + name: "without-connection-allocFunc", + protocol: Protocol("TEST_BSR_OPEN_CONNECTION_PROTOCOL"), + sId: 45678, + sessionAllocFunc: &BaseSessionSummary{Id: "TEST_SESSION_ID", ConnectionCount: 1}, + connectionAllocFunc: nil, + channelAllocFunc: &BaseChannelSummary{Id: "TEST_CHANNEL_ID", ConnectionRecordingId: "TEST_CONNECTION_RECORDING_ID"}, + expectedError: "bsr.(Session).OpenConnection: failed to get summary type", + wantConnErr: true, + }, + { + name: "without-channel-allocFunc", + protocol: Protocol("TEST_BSR_OPEN_CHANNEL_PROTOCOL"), + sId: 23588, + sessionAllocFunc: &BaseSessionSummary{Id: "TEST_SESSION_ID", ConnectionCount: 1}, + connectionAllocFunc: &BaseConnectionSummary{Id: "TEST_CONNECTION_ID", ChannelCount: 1}, + expectedError: "bsr.OpenChannel: failed to get summary type", + wantChanErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.sessionAllocFunc != nil { + err := RegisterSummaryAllocFunc(tc.protocol, SessionContainer, func(ctx context.Context) Summary { + return tc.sessionAllocFunc + }) + require.NoError(t, err) + } + if tc.connectionAllocFunc != nil { + err := RegisterSummaryAllocFunc(tc.protocol, ConnectionContainer, func(ctx context.Context) Summary { + return tc.connectionAllocFunc + }) + require.NoError(t, err) + } + if tc.channelAllocFunc != nil { + err := RegisterSummaryAllocFunc(tc.protocol, ChannelContainer, func(ctx context.Context) Summary { + return tc.channelAllocFunc + }) + require.NoError(t, err) + } + + keys, err := kms.CreateKeys(ctx, kms.TestWrapper(t), "session") + require.NoError(t, err) + + sessionId := fmt.Sprintf("s_%v", tc.sId) + srm := &SessionRecordingMeta{ + Id: fmt.Sprintf("sr_%v", tc.sId), + Protocol: tc.protocol, + } + sessionMeta := TestSessionMeta(sessionId) + + sesh, err := NewSession(ctx, srm, sessionMeta, f, keys, WithSupportsMultiplex(true)) + require.NoError(t, err) + require.NotNil(t, sesh) + + connectionId := "connection" + connMeta := &ConnectionRecordingMeta{Id: connectionId} + conn, err := sesh.NewConnection(ctx, connMeta) + require.NoError(t, err) + require.NotNil(t, conn) + + channelId := "channel" + chanMeta := &ChannelRecordingMeta{ + Id: channelId, + Type: "chan", + } + ch, err := conn.NewChannel(ctx, chanMeta) + require.NoError(t, err) + require.NotNil(t, ch) + + sesh.EncodeSummary(ctx, tc.sessionAllocFunc) + conn.EncodeSummary(ctx, tc.connectionAllocFunc) + ch.EncodeSummary(ctx, tc.channelAllocFunc) + + ch.Close(ctx) + conn.Close(ctx) + sesh.Close(ctx) + + keyFn := func(w kms.WrappedKeys) (kms.UnwrappedKeys, error) { + u := kms.UnwrappedKeys{ + BsrKey: keys.BsrKey, + PrivKey: keys.PrivKey, + } + return u, nil + } + + opSesh, err := OpenSession(ctx, srm.Id, f, keyFn) + if tc.wantSessionErr { + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedError) + return + } + require.NoError(t, err) + require.NotNil(t, opSesh) + + opConn, err := opSesh.OpenConnection(ctx, connectionId) + if tc.wantConnErr { + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedError) + return + } + require.NoError(t, err) + require.NotNil(t, opConn) + + opChan, err := opConn.OpenChannel(ctx, channelId) + if tc.wantChanErr { + require.Error(t, err) + require.ErrorContains(t, err, tc.expectedError) + return + } + require.NoError(t, err) + require.NotNil(t, opChan) + }) + } +} diff --git a/internal/bsr/container.go b/internal/bsr/container.go index d9c1fd3ade..4fd7a68f6d 100644 --- a/internal/bsr/container.go +++ b/internal/bsr/container.go @@ -31,13 +31,13 @@ const ( ) // ContainerType defines the type of container. -type containerType string +type ContainerType string // Valid container types. const ( - sessionContainer containerType = "session" - connectionContainer containerType = "connection" - channelContainer containerType = "channel" + SessionContainer ContainerType = "session" + ConnectionContainer ContainerType = "connection" + ChannelContainer ContainerType = "channel" ) // container contains a group of files in a BSR. @@ -64,7 +64,7 @@ type container struct { } // newContainer creates a container for the given type backed by the provide storage.Container. -func newContainer(ctx context.Context, t containerType, c storage.Container, keys *kms.Keys) (*container, error) { +func newContainer(ctx context.Context, t ContainerType, c storage.Container, keys *kms.Keys) (*container, error) { j, err := c.OpenFile(ctx, journalFileName, storage.WithCreateFile(), storage.WithFileAccessMode(storage.WriteOnly), @@ -124,7 +124,7 @@ func newContainer(ctx context.Context, t containerType, c storage.Container, key type populateKeyFunc func(c *container) (*kms.Keys, error) // openContainer will set keys and load and verify the checksums for this container -func openContainer(ctx context.Context, t containerType, c storage.Container, keyGetFunc populateKeyFunc) (*container, error) { +func openContainer(ctx context.Context, t ContainerType, c storage.Container, keyGetFunc populateKeyFunc) (*container, error) { const op = "bsr.openContainer" switch { case t == "": diff --git a/internal/bsr/container_test.go b/internal/bsr/container_test.go index 6353425092..99076082ee 100644 --- a/internal/bsr/container_test.go +++ b/internal/bsr/container_test.go @@ -23,7 +23,7 @@ func TestSyncBsrKeys(t *testing.T) { fc, err := f.New(ctx, fmt.Sprintf(bsrFileNameTemplate, "session-id")) require.NoError(t, err) - c, err := newContainer(ctx, sessionContainer, fc, keys) + c, err := newContainer(ctx, SessionContainer, fc, keys) require.NoError(t, err) require.NotNil(t, c) diff --git a/internal/bsr/errors.go b/internal/bsr/errors.go index 795931d44d..143fe0cc9f 100644 --- a/internal/bsr/errors.go +++ b/internal/bsr/errors.go @@ -27,8 +27,8 @@ var ( // particular protocol. ErrNotSupported = errors.New("not supported by protocol") - // ErrAlreadyRegistered is an error with registering chunk decoder functions. - ErrAlreadyRegistered = errors.New("chunk type already registered") + // ErrAlreadyRegistered is an error with registering functions. + ErrAlreadyRegistered = errors.New("type already registered") // ErrEndChunkNotEmpty indicates a malformed END chunk. ErrEndChunkNotEmpty = errors.New("end chunk not empty") diff --git a/internal/bsr/ssh/types.go b/internal/bsr/ssh/types.go index fc9adf0322..326b5b980d 100644 --- a/internal/bsr/ssh/types.go +++ b/internal/bsr/ssh/types.go @@ -4,10 +4,27 @@ package ssh import ( + "context" + "time" + "github.com/hashicorp/boundary/internal/bsr" "golang.org/x/crypto/ssh" ) +func init() { + if err := bsr.RegisterSummaryAllocFunc(Protocol, bsr.ChannelContainer, allocChannelSummary); err != nil { + panic(err) + } + + if err := bsr.RegisterSummaryAllocFunc(Protocol, bsr.SessionContainer, bsr.AllocSessionSummary); err != nil { + panic(err) + } + + if err := bsr.RegisterSummaryAllocFunc(Protocol, bsr.ConnectionContainer, bsr.AllocConnectionSummary); err != nil { + panic(err) + } +} + // SessionProgram identifies the program running on this channel // as outlined in https://www.rfc-editor.org/rfc/rfc4254.html#section-6.5 : // @@ -85,10 +102,49 @@ type OpenChannelError ssh.OpenChannelError // // OpenFailure will be nil if the Channel was successfully opened. type ChannelSummary struct { - ChannelSummary *bsr.ChannelSummary + ChannelSummary *bsr.BaseChannelSummary SessionProgram SessionProgram SubsystemName string ExecProgram ExecApplicationProgram FileTransferDirection FileTransferDirection OpenFailure *OpenChannelError `json:",omitempty"` } + +// GetId returns the Id of the container. +func (c *ChannelSummary) GetId() string { + return c.ChannelSummary.Id +} + +// GetId returns the Id of the container. +func (c *ChannelSummary) GetConnectionRecordingId() string { + return c.ChannelSummary.ConnectionRecordingId +} + +// GetStartTime returns the start time using a monotonic clock. +func (c *ChannelSummary) GetStartTime() time.Time { + return c.ChannelSummary.StartTime +} + +// GetEndTime returns the end time using a monotonic clock. +func (c *ChannelSummary) GetEndTime() time.Time { + return c.ChannelSummary.EndTime +} + +// GetBytesUp returns upload bytes. +func (c *ChannelSummary) GetBytesUp() uint64 { + return c.ChannelSummary.BytesUp +} + +// GetBytesDown returns download bytes. +func (c *ChannelSummary) GetBytesDown() uint64 { + return c.ChannelSummary.BytesDown +} + +// GetChannelType the type of summary channel. +func (c *ChannelSummary) GetChannelType() string { + return c.ChannelSummary.ChannelType +} + +func allocChannelSummary(_ context.Context) bsr.Summary { + return &ChannelSummary{} +} diff --git a/internal/bsr/summary.go b/internal/bsr/summary.go new file mode 100644 index 0000000000..cbfb82167d --- /dev/null +++ b/internal/bsr/summary.go @@ -0,0 +1,51 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bsr + +import ( + "context" + "fmt" +) + +// SummaryAllocFunc is a function that returns a summary type +type SummaryAllocFunc func(ctx context.Context) Summary + +// summaryAllocFuncRegistry mappings of protocols and container type +// for each SummaryAllocFunc +type summaryAllocFuncRegistry map[Protocol]map[ContainerType]SummaryAllocFunc + +func (r summaryAllocFuncRegistry) get(p Protocol, c ContainerType) (SummaryAllocFunc, bool) { + protocol, ok := r[p] + if !ok { + return nil, false + } + af, ok := protocol[c] + return af, ok +} + +var summaryAllocFuncs summaryAllocFuncRegistry + +// RegisterSummaryAllocFunc registers a SummaryAllocFunc for the given Protocol. +// A given Protocol and Container can only have one SummaryAllocFunc function +// registered. +func RegisterSummaryAllocFunc(p Protocol, c ContainerType, af SummaryAllocFunc) error { + const op = "bsr.RegisterSummaryAllocFunc" + + if summaryAllocFuncs == nil { + summaryAllocFuncs = make(map[Protocol]map[ContainerType]SummaryAllocFunc) + } + + protocol, ok := summaryAllocFuncs[p] + if !ok { + protocol = make(map[ContainerType]SummaryAllocFunc) + } + + _, ok = protocol[c] + if ok { + return fmt.Errorf("%s: %s protocol with %s container: %w", op, p, c, ErrAlreadyRegistered) + } + protocol[c] = af + summaryAllocFuncs[p] = protocol + return nil +} diff --git a/internal/bsr/summary_test.go b/internal/bsr/summary_test.go new file mode 100644 index 0000000000..89ffca091f --- /dev/null +++ b/internal/bsr/summary_test.go @@ -0,0 +1,190 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package bsr + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegisterSummaryAllocFunc_TestProtocol(t *testing.T) { + ctx := context.Background() + startTime := time.Now() + endTime := time.Now() + + cases := []struct { + name string + p Protocol + c ContainerType + cf SummaryAllocFunc + wantP Protocol + want *BaseSummary + wantRegisterErr error + wantGetAllocErr bool + }{ + { + "valid summary", + Protocol("TEST_PROTOCOL"), + ChannelContainer, + func(ctx context.Context) Summary { + return &BaseSummary{ + Id: "TEST_ID", + StartTime: startTime, + EndTime: endTime, + } + }, + Protocol("TEST_PROTOCOL"), + &BaseSummary{ + Id: "TEST_ID", + StartTime: startTime, + EndTime: endTime, + }, + nil, + false, + }, + { + "already-registered-container", + Protocol("TEST_PROTOCOL"), + ChannelContainer, + nil, + Protocol("TEST_PROTOCOL"), + &BaseSummary{}, + errors.New("bsr.RegisterSummaryAllocFunc: TEST_PROTOCOL protocol with channel container: type already registered"), + false, + }, + { + "invalid-protocol", + Protocol("TEST_PROTOCOL_2"), + ChannelContainer, + nil, + Protocol("TEST_INVALID_PROTOCOL"), + &BaseSummary{}, + nil, + true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := RegisterSummaryAllocFunc(tc.p, tc.c, tc.cf) + if tc.wantRegisterErr != nil { + assert.EqualError(t, tc.wantRegisterErr, err.Error()) + return + } + require.NoError(t, err) + + af, ok := summaryAllocFuncs.get(tc.wantP, tc.c) + if tc.wantGetAllocErr { + require.False(t, ok, "found invalid summary") + return + } + require.True(t, ok, "could not get summary") + + got := af(ctx) + + assert.Equal(t, tc.want.GetId(), got.GetId()) + assert.Equal(t, tc.want.GetStartTime(), got.GetStartTime()) + assert.Equal(t, tc.want.GetEndTime(), got.GetEndTime()) + }) + } +} + +func TestRegisterSummaryAllocFunc_TestChannel(t *testing.T) { + ctx := context.Background() + + protocol := Protocol("TEST_CHANNEL_PROTOCOL") + chs := &BaseChannelSummary{ + Id: "TEST_ID", + ConnectionRecordingId: "TEST_CONNECTION_RECORDING_ID", + ChannelType: "CONTAINER", + StartTime: time.Now(), + EndTime: time.Now(), + BytesUp: 100, + BytesDown: 200, + } + + err := RegisterSummaryAllocFunc(protocol, ChannelContainer, func(ctx context.Context) Summary { + return chs + }) + require.NoError(t, err) + + af, ok := summaryAllocFuncs.get(protocol, ChannelContainer) + require.True(t, ok, "could not get channel summary") + + cf := af(ctx) + got := cf.(*BaseChannelSummary) + + assert.Equal(t, chs.GetId(), got.GetId()) + assert.Equal(t, chs.GetConnectionRecordingId(), got.GetConnectionRecordingId()) + assert.Equal(t, chs.GetChannelType(), got.GetChannelType()) + assert.Equal(t, chs.GetStartTime(), got.GetStartTime()) + assert.Equal(t, chs.GetEndTime(), got.GetEndTime()) + assert.Equal(t, chs.GetBytesUp(), got.GetBytesUp()) + assert.Equal(t, chs.GetBytesDown(), got.GetBytesDown()) +} + +func TestRegisterSummaryAllocFunc_TestConnection(t *testing.T) { + ctx := context.Background() + + protocol := Protocol("TEST_CONNECTION_PROTOCOL") + chs := &BaseConnectionSummary{ + Id: "TEST_ID", + ChannelCount: 1, + StartTime: time.Now(), + EndTime: time.Now(), + BytesUp: 100, + BytesDown: 200, + } + + err := RegisterSummaryAllocFunc(protocol, ConnectionContainer, func(ctx context.Context) Summary { + return chs + }) + require.NoError(t, err) + + af, ok := summaryAllocFuncs.get(protocol, ConnectionContainer) + require.True(t, ok, "could not get connection summary") + + cf := af(ctx) + got := cf.(*BaseConnectionSummary) + + assert.Equal(t, chs.GetId(), got.GetId()) + assert.Equal(t, chs.GetChannelCount(), got.GetChannelCount()) + assert.Equal(t, chs.GetStartTime(), got.GetStartTime()) + assert.Equal(t, chs.GetEndTime(), got.GetEndTime()) + assert.Equal(t, chs.GetBytesUp(), got.GetBytesUp()) + assert.Equal(t, chs.GetBytesDown(), got.GetBytesDown()) +} + +func TestRegisterSummaryAllocFunc_TestSession(t *testing.T) { + ctx := context.Background() + + protocol := Protocol("TEST_SESSION_PROTOCOL") + chs := &BaseSessionSummary{ + Id: "TEST_ID", + ConnectionCount: 1, + StartTime: time.Now(), + EndTime: time.Now(), + } + + err := RegisterSummaryAllocFunc(protocol, SessionContainer, func(ctx context.Context) Summary { + return chs + }) + require.NoError(t, err) + + af, ok := summaryAllocFuncs.get(protocol, SessionContainer) + require.True(t, ok, "could not get session summary") + + cf := af(ctx) + got := cf.(*BaseSessionSummary) + + assert.Equal(t, chs.GetId(), got.GetId()) + assert.Equal(t, chs.GetConnectionCount(), got.GetConnectionCount()) + assert.Equal(t, chs.GetStartTime(), got.GetStartTime()) + assert.Equal(t, chs.GetEndTime(), got.GetEndTime()) +} diff --git a/internal/bsr/types.go b/internal/bsr/types.go index 6f318c061a..655e36d2e3 100644 --- a/internal/bsr/types.go +++ b/internal/bsr/types.go @@ -3,40 +3,202 @@ package bsr -import "time" - -type ( - // SessionSummary encapsulates data for a session, including its session id, connection count, - // and start/end time using a monotonic clock - SessionSummary struct { - Id string - ConnectionCount uint64 - StartTime time.Time - EndTime time.Time - Errors error - } - - // ConnectionSummary encapsulates data for a connection, including its connection id, channel count, - // start/end time using a monotonic clock, and the aggregate bytes up/ down of its channels - ConnectionSummary struct { - Id string - ChannelCount uint64 - StartTime time.Time - EndTime time.Time - BytesUp uint64 - BytesDown uint64 - Errors error - } - - // ChannelSummary encapsulates data for a channel, including its id, channel type, - // start/end time using a monotonic clock, and the bytes up/ down seen on this channel - ChannelSummary struct { - Id string - ConnectionRecordingId string - StartTime time.Time - EndTime time.Time - BytesUp uint64 - BytesDown uint64 - ChannelType string - } +import ( + "context" + "time" ) + +// BaseSummary contains the common fields of all summary types. +type BaseSummary struct { + Id string + StartTime time.Time + EndTime time.Time + Errors error +} + +// Summary contains statistics for a container +type Summary interface { + // GetId returns the Id of the container. + GetId() string + // GetStartTime returns the start time using a monotonic clock. + GetStartTime() time.Time + // GetEndTime returns the end time using a monotonic clock. + GetEndTime() time.Time +} + +// GetId returns the Id of the container. +func (b *BaseSummary) GetId() string { + return b.Id +} + +// GetStartTime returns the start time using a monotonic clock. +func (b *BaseSummary) GetStartTime() time.Time { + return b.StartTime +} + +// GetEndTime returns the end time using a monotonic clock. +func (b *BaseSummary) GetEndTime() time.Time { + return b.EndTime +} + +// BaseSessionSummary encapsulates data for a session, including its session id, connection count, +// and start/end time using a monotonic clock +type BaseSessionSummary struct { + Id string + ConnectionCount uint64 + StartTime time.Time + EndTime time.Time + Errors error +} + +// SessionSummary contains statistics for a session container +type SessionSummary interface { + Summary + // GetConnectionCount returns the connection count. + GetConnectionCount() uint64 +} + +func AllocSessionSummary(_ context.Context) Summary { + return &BaseSessionSummary{} +} + +// GetId returns the Id of the container. +func (b *BaseSessionSummary) GetId() string { + return b.Id +} + +// GetStartTime returns the start time using a monotonic clock. +func (b *BaseSessionSummary) GetStartTime() time.Time { + return b.StartTime +} + +// GetEndTime returns the end time using a monotonic clock. +func (b *BaseSessionSummary) GetEndTime() time.Time { + return b.EndTime +} + +// GetConnectionCount returns the connection count. +func (b *BaseSessionSummary) GetConnectionCount() uint64 { + return b.ConnectionCount +} + +// BaseChannelSummary encapsulates data for a channel, including its id, channel type, +// start/end time using a monotonic clock, and the bytes up/ down seen on this channel +type BaseChannelSummary struct { + Id string + ConnectionRecordingId string + StartTime time.Time + EndTime time.Time + BytesUp uint64 + BytesDown uint64 + ChannelType string +} + +// ChannelSummary contains statistics for a channel container +type ChannelSummary interface { + Summary + // GetConnectionRecordingId returns the connection recording id of the channel. + GetConnectionRecordingId() string + // GetBytesUp returns upload bytes. + GetBytesUp() uint64 + // BytesDown returns download bytes. + GetBytesDown() uint64 + // GetChannelType the type of summary channel. + GetChannelType() string +} + +func AllocChannelSummary(_ context.Context) Summary { + return &BaseChannelSummary{} +} + +// GetId returns the Id of the container. +func (b *BaseChannelSummary) GetId() string { + return b.Id +} + +// GetId returns the Id of the container. +func (b *BaseChannelSummary) GetConnectionRecordingId() string { + return b.ConnectionRecordingId +} + +// GetStartTime returns the start time using a monotonic clock. +func (b *BaseChannelSummary) GetStartTime() time.Time { + return b.StartTime +} + +// GetEndTime returns the end time using a monotonic clock. +func (b *BaseChannelSummary) GetEndTime() time.Time { + return b.EndTime +} + +// GetBytesUp returns upload bytes. +func (b *BaseChannelSummary) GetBytesUp() uint64 { + return b.BytesUp +} + +// GetBytesDown returns download bytes. +func (b *BaseChannelSummary) GetBytesDown() uint64 { + return b.BytesDown +} + +// GetChannelType returns the type of summary channel. +func (b *BaseChannelSummary) GetChannelType() string { + return b.ChannelType +} + +// BaseConnectionSummary encapsulates data for a connection, including its connection id, channel count, +// start/end time using a monotonic clock, and the aggregate bytes up/ down of its channels +type BaseConnectionSummary struct { + Id string + ChannelCount uint64 + StartTime time.Time + EndTime time.Time + BytesUp uint64 + BytesDown uint64 + Errors error +} + +// ConnectionSummary contains statistics for a connection container +type ConnectionSummary interface { + Summary + // GetChannelCount returns the channel count. + GetChannelCount() uint64 + // GetBytesUp returns upload bytes. + GetBytesUp() uint64 + // BytesDown returns download bytes. + GetBytesDown() uint64 +} + +func AllocConnectionSummary(_ context.Context) Summary { + return &BaseConnectionSummary{} +} + +// GetChannelCount returns the channel count. +func (b *BaseConnectionSummary) GetChannelCount() uint64 { + return b.ChannelCount +} + +// GetId returns the Id of the container. +func (b *BaseConnectionSummary) GetId() string { + return b.Id +} + +// GetStartTime returns the start time using a monotonic clock. +func (b *BaseConnectionSummary) GetStartTime() time.Time { + return b.StartTime +} + +// GetEndTime returns the end time using a monotonic clock. +func (b *BaseConnectionSummary) GetEndTime() time.Time { + return b.EndTime +} + +// GetBytesUp BaseConnectionSummary upload bytes. +func (b *BaseConnectionSummary) GetBytesUp() uint64 { + return b.BytesUp +} + +// GetBytesDown returns download bytes. +func (b *BaseConnectionSummary) GetBytesDown() uint64 { + return b.BytesDown +}