add ConnectionState type with unit tests

jimlambrt-session-basics
Jim Lambert 6 years ago
parent d6e7a7c546
commit e71b1ba75f

@ -0,0 +1,142 @@
package session
import (
"context"
"fmt"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
defaultConnectionStateTableName = "session_connection_state"
)
// ConnectionStatus of the connection's state
type ConnectionStatus string
const (
StatusConnected ConnectionStatus = "connected"
StatusClosed ConnectionStatus = "closed"
)
// String representation of the state's status
func (s ConnectionStatus) String() string {
return string(s)
}
// State of the session
type ConnectionState struct {
// ConnectionId is used to access the state via an API
ConnectionId string `json:"public_id,omitempty" gorm:"primary_key"`
// status of the connection
Status string `protobuf:"bytes,20,opt,name=status,proto3" json:"status,omitempty" gorm:"column:state"`
// PreviousEndTime from the RDBMS
PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"`
// StartTime from the RDBMS
StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"`
// EndTime from the RDBMS
EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"`
tableName string `gorm:"-"`
}
var _ Cloneable = (*ConnectionState)(nil)
var _ db.VetForWriter = (*ConnectionState)(nil)
// NewConnectionState creates a new in memory connection state. No options
// are currently supported.
func NewConnectionState(connectionId string, state ConnectionStatus, opt ...Option) (*ConnectionState, error) {
s := ConnectionState{
ConnectionId: connectionId,
Status: state.String(),
}
if err := s.validate("new connection state:"); err != nil {
return nil, err
}
return &s, nil
}
// allocConnectionState will allocate a connection State
func allocConnectionState() ConnectionState {
return ConnectionState{}
}
// Clone creates a clone of the State
func (s *ConnectionState) Clone() interface{} {
clone := &ConnectionState{
ConnectionId: s.ConnectionId,
Status: s.Status,
}
if s.PreviousEndTime != nil {
clone.PreviousEndTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.PreviousEndTime.Timestamp.Seconds,
Nanos: s.PreviousEndTime.Timestamp.Nanos,
},
}
}
if s.StartTime != nil {
clone.StartTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.StartTime.Timestamp.Seconds,
Nanos: s.StartTime.Timestamp.Nanos,
},
}
}
if s.EndTime != nil {
clone.EndTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.EndTime.Timestamp.Seconds,
Nanos: s.EndTime.Timestamp.Nanos,
},
}
}
return clone
}
// VetForWrite implements db.VetForWrite() interface and validates the state
// before it's written.
func (s *ConnectionState) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
if err := s.validate("connection state vet for write:"); err != nil {
return err
}
return nil
}
// TableName returns the tablename to override the default gorm table name
func (s *ConnectionState) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultConnectionStateTableName
}
// SetTableName sets the tablename and satisfies the ReplayableMessage
// interface. If the caller attempts to set the name to "" the name will be
// reset to the default name.
func (s *ConnectionState) SetTableName(n string) {
s.tableName = n
}
// validate checks the session state
func (s *ConnectionState) validate(errorPrefix string) error {
if s.Status == "" {
return fmt.Errorf("%s missing status: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.ConnectionId == "" {
return fmt.Errorf("%s missing connection id: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.StartTime != nil {
return fmt.Errorf("%s start time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.EndTime != nil {
return fmt.Errorf("%s end time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.PreviousEndTime != nil {
return fmt.Errorf("%s previous end time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
}
return nil
}

@ -0,0 +1,212 @@
package session
import (
"context"
"errors"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConnectionState_Create(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
session := TestDefaultSession(t, conn, wrapper, iamRepo)
connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443)
type args struct {
connectionId string
status ConnectionStatus
}
tests := []struct {
name string
args args
want *ConnectionState
wantErr bool
wantIsErr error
create bool
wantCreateErr bool
}{
{
name: "valid",
args: args{
connectionId: connection.PublicId,
status: StatusClosed,
},
want: &ConnectionState{
ConnectionId: connection.PublicId,
Status: StatusClosed.String(),
},
create: true,
},
{
name: "empty-connectionId",
args: args{
status: StatusClosed,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
{
name: "empty-status",
args: args{
connectionId: connection.PublicId,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := NewConnectionState(tt.args.connectionId, tt.args.status)
if tt.wantErr {
require.Error(err)
assert.True(errors.Is(err, tt.wantIsErr))
return
}
require.NoError(err)
assert.Equal(tt.want, got)
if tt.create {
err = db.New(conn).Create(context.Background(), got)
if tt.wantCreateErr {
assert.Error(err)
return
} else {
assert.NoError(err)
}
}
})
}
}
func TestConnectionState_Delete(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
tests := []struct {
name string
state *ConnectionState
deleteConnectionStateId string
wantRowsDeleted int
wantErr bool
wantErrMsg string
}{
{
name: "valid",
state: TestConnectionState(t, conn, c.PublicId, StatusClosed),
wantErr: false,
wantRowsDeleted: 1,
},
{
name: "bad-id",
state: TestConnectionState(t, conn, c2.PublicId, StatusClosed),
deleteConnectionStateId: func() string {
id, err := db.NewPublicId(ConnectionStatePrefix)
require.NoError(t, err)
return id
}(),
wantErr: false,
wantRowsDeleted: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var initialState ConnectionState
err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", tt.state.ConnectionId, tt.state.Status)
require.NoError(err)
deleteState := allocConnectionState()
if tt.deleteConnectionStateId != "" {
deleteState.ConnectionId = tt.deleteConnectionStateId
} else {
deleteState.ConnectionId = tt.state.ConnectionId
}
deleteState.StartTime = initialState.StartTime
deletedRows, err := rw.Delete(context.Background(), &deleteState)
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
if tt.wantRowsDeleted == 0 {
assert.Equal(tt.wantRowsDeleted, deletedRows)
return
}
assert.Equal(tt.wantRowsDeleted, deletedRows)
foundState := allocConnectionState()
err = rw.LookupWhere(context.Background(), &foundState, "connection_id = ? and start_time = ?", tt.state.ConnectionId, initialState.StartTime)
require.Error(err)
assert.True(errors.Is(db.ErrRecordNotFound, err))
})
}
}
func TestConnectionState_Clone(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
t.Run("valid", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
cp := state.Clone()
assert.Equal(cp.(*ConnectionState), state)
})
t.Run("not-equal", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected)
cp := state.Clone()
assert.NotEqual(cp.(*ConnectionState), state2)
})
}
func TestConnectionState_SetTableName(t *testing.T) {
t.Parallel()
defaultTableName := defaultConnectionStateTableName
tests := []struct {
name string
setNameTo string
want string
}{
{
name: "new-name",
setNameTo: "new-name",
want: "new-name",
},
{
name: "reset to default",
setNameTo: "",
want: defaultTableName,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
def := allocConnectionState()
require.Equal(defaultTableName, def.TableName())
s := allocConnectionState()
s.SetTableName(tt.setNameTo)
assert.Equal(tt.want, s.TableName())
})
}
}
Loading…
Cancel
Save