diff --git a/internal/session/connection_state.go b/internal/session/connection_state.go new file mode 100644 index 0000000000..4110b4149c --- /dev/null +++ b/internal/session/connection_state.go @@ -0,0 +1,142 @@ +package session + +import ( + "context" + "fmt" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + defaultConnectionStateTableName = "session_connection_state" +) + +// ConnectionStatus of the connection's state +type ConnectionStatus string + +const ( + StatusConnected ConnectionStatus = "connected" + StatusClosed ConnectionStatus = "closed" +) + +// String representation of the state's status +func (s ConnectionStatus) String() string { + return string(s) +} + +// State of the session +type ConnectionState struct { + // ConnectionId is used to access the state via an API + ConnectionId string `json:"public_id,omitempty" gorm:"primary_key"` + // status of the connection + Status string `protobuf:"bytes,20,opt,name=status,proto3" json:"status,omitempty" gorm:"column:state"` + // PreviousEndTime from the RDBMS + PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` + // StartTime from the RDBMS + StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` + // EndTime from the RDBMS + EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"` + + tableName string `gorm:"-"` +} + +var _ Cloneable = (*ConnectionState)(nil) +var _ db.VetForWriter = (*ConnectionState)(nil) + +// NewConnectionState creates a new in memory connection state. No options +// are currently supported. +func NewConnectionState(connectionId string, state ConnectionStatus, opt ...Option) (*ConnectionState, error) { + s := ConnectionState{ + ConnectionId: connectionId, + Status: state.String(), + } + if err := s.validate("new connection state:"); err != nil { + return nil, err + } + return &s, nil +} + +// allocConnectionState will allocate a connection State +func allocConnectionState() ConnectionState { + return ConnectionState{} +} + +// Clone creates a clone of the State +func (s *ConnectionState) Clone() interface{} { + clone := &ConnectionState{ + ConnectionId: s.ConnectionId, + Status: s.Status, + } + if s.PreviousEndTime != nil { + clone.PreviousEndTime = ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: s.PreviousEndTime.Timestamp.Seconds, + Nanos: s.PreviousEndTime.Timestamp.Nanos, + }, + } + } + + if s.StartTime != nil { + clone.StartTime = ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: s.StartTime.Timestamp.Seconds, + Nanos: s.StartTime.Timestamp.Nanos, + }, + } + } + if s.EndTime != nil { + clone.EndTime = ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: s.EndTime.Timestamp.Seconds, + Nanos: s.EndTime.Timestamp.Nanos, + }, + } + } + return clone +} + +// VetForWrite implements db.VetForWrite() interface and validates the state +// before it's written. +func (s *ConnectionState) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error { + if err := s.validate("connection state vet for write:"); err != nil { + return err + } + return nil +} + +// TableName returns the tablename to override the default gorm table name +func (s *ConnectionState) TableName() string { + if s.tableName != "" { + return s.tableName + } + return defaultConnectionStateTableName +} + +// SetTableName sets the tablename and satisfies the ReplayableMessage +// interface. If the caller attempts to set the name to "" the name will be +// reset to the default name. +func (s *ConnectionState) SetTableName(n string) { + s.tableName = n +} + +// validate checks the session state +func (s *ConnectionState) validate(errorPrefix string) error { + if s.Status == "" { + return fmt.Errorf("%s missing status: %w", errorPrefix, db.ErrInvalidParameter) + } + if s.ConnectionId == "" { + return fmt.Errorf("%s missing connection id: %w", errorPrefix, db.ErrInvalidParameter) + } + if s.StartTime != nil { + return fmt.Errorf("%s start time is not settable: %w", errorPrefix, db.ErrInvalidParameter) + } + if s.EndTime != nil { + return fmt.Errorf("%s end time is not settable: %w", errorPrefix, db.ErrInvalidParameter) + } + if s.PreviousEndTime != nil { + return fmt.Errorf("%s previous end time is not settable: %w", errorPrefix, db.ErrInvalidParameter) + } + return nil +} diff --git a/internal/session/connection_state_test.go b/internal/session/connection_state_test.go new file mode 100644 index 0000000000..eeabeb8713 --- /dev/null +++ b/internal/session/connection_state_test.go @@ -0,0 +1,212 @@ +package session + +import ( + "context" + "errors" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/iam" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnectionState_Create(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443) + + type args struct { + connectionId string + status ConnectionStatus + } + tests := []struct { + name string + args args + want *ConnectionState + wantErr bool + wantIsErr error + create bool + wantCreateErr bool + }{ + { + name: "valid", + args: args{ + connectionId: connection.PublicId, + status: StatusClosed, + }, + want: &ConnectionState{ + ConnectionId: connection.PublicId, + Status: StatusClosed.String(), + }, + create: true, + }, + { + name: "empty-connectionId", + args: args{ + status: StatusClosed, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "empty-status", + args: args{ + connectionId: connection.PublicId, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + got, err := NewConnectionState(tt.args.connectionId, tt.args.status) + if tt.wantErr { + require.Error(err) + assert.True(errors.Is(err, tt.wantIsErr)) + return + } + require.NoError(err) + assert.Equal(tt.want, got) + if tt.create { + err = db.New(conn).Create(context.Background(), got) + if tt.wantCreateErr { + assert.Error(err) + return + } else { + assert.NoError(err) + } + } + }) + } +} + +func TestConnectionState_Delete(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + + tests := []struct { + name string + state *ConnectionState + deleteConnectionStateId string + wantRowsDeleted int + wantErr bool + wantErrMsg string + }{ + { + name: "valid", + state: TestConnectionState(t, conn, c.PublicId, StatusClosed), + wantErr: false, + wantRowsDeleted: 1, + }, + { + name: "bad-id", + state: TestConnectionState(t, conn, c2.PublicId, StatusClosed), + deleteConnectionStateId: func() string { + id, err := db.NewPublicId(ConnectionStatePrefix) + require.NoError(t, err) + return id + }(), + wantErr: false, + wantRowsDeleted: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + var initialState ConnectionState + err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", tt.state.ConnectionId, tt.state.Status) + require.NoError(err) + + deleteState := allocConnectionState() + if tt.deleteConnectionStateId != "" { + deleteState.ConnectionId = tt.deleteConnectionStateId + } else { + deleteState.ConnectionId = tt.state.ConnectionId + } + deleteState.StartTime = initialState.StartTime + deletedRows, err := rw.Delete(context.Background(), &deleteState) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + if tt.wantRowsDeleted == 0 { + assert.Equal(tt.wantRowsDeleted, deletedRows) + return + } + assert.Equal(tt.wantRowsDeleted, deletedRows) + foundState := allocConnectionState() + err = rw.LookupWhere(context.Background(), &foundState, "connection_id = ? and start_time = ?", tt.state.ConnectionId, initialState.StartTime) + require.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + }) + } +} + +func TestConnectionState_Clone(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + t.Run("valid", func(t *testing.T) { + assert := assert.New(t) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + state := TestConnectionState(t, conn, c.PublicId, StatusConnected) + cp := state.Clone() + assert.Equal(cp.(*ConnectionState), state) + }) + t.Run("not-equal", func(t *testing.T) { + assert := assert.New(t) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + state := TestConnectionState(t, conn, c.PublicId, StatusConnected) + state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected) + + cp := state.Clone() + assert.NotEqual(cp.(*ConnectionState), state2) + }) +} + +func TestConnectionState_SetTableName(t *testing.T) { + t.Parallel() + defaultTableName := defaultConnectionStateTableName + tests := []struct { + name string + setNameTo string + want string + }{ + { + name: "new-name", + setNameTo: "new-name", + want: "new-name", + }, + { + name: "reset to default", + setNameTo: "", + want: defaultTableName, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + def := allocConnectionState() + require.Equal(defaultTableName, def.TableName()) + s := allocConnectionState() + s.SetTableName(tt.setNameTo) + assert.Equal(tt.want, s.TableName()) + }) + } +}