mirror of https://github.com/hashicorp/boundary
parent
d6e7a7c546
commit
e71b1ba75f
@ -0,0 +1,142 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/db/timestamp"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConnectionStateTableName = "session_connection_state"
|
||||
)
|
||||
|
||||
// ConnectionStatus of the connection's state
|
||||
type ConnectionStatus string
|
||||
|
||||
const (
|
||||
StatusConnected ConnectionStatus = "connected"
|
||||
StatusClosed ConnectionStatus = "closed"
|
||||
)
|
||||
|
||||
// String representation of the state's status
|
||||
func (s ConnectionStatus) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// State of the session
|
||||
type ConnectionState struct {
|
||||
// ConnectionId is used to access the state via an API
|
||||
ConnectionId string `json:"public_id,omitempty" gorm:"primary_key"`
|
||||
// status of the connection
|
||||
Status string `protobuf:"bytes,20,opt,name=status,proto3" json:"status,omitempty" gorm:"column:state"`
|
||||
// PreviousEndTime from the RDBMS
|
||||
PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"`
|
||||
// StartTime from the RDBMS
|
||||
StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"`
|
||||
// EndTime from the RDBMS
|
||||
EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"`
|
||||
|
||||
tableName string `gorm:"-"`
|
||||
}
|
||||
|
||||
var _ Cloneable = (*ConnectionState)(nil)
|
||||
var _ db.VetForWriter = (*ConnectionState)(nil)
|
||||
|
||||
// NewConnectionState creates a new in memory connection state. No options
|
||||
// are currently supported.
|
||||
func NewConnectionState(connectionId string, state ConnectionStatus, opt ...Option) (*ConnectionState, error) {
|
||||
s := ConnectionState{
|
||||
ConnectionId: connectionId,
|
||||
Status: state.String(),
|
||||
}
|
||||
if err := s.validate("new connection state:"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
// allocConnectionState will allocate a connection State
|
||||
func allocConnectionState() ConnectionState {
|
||||
return ConnectionState{}
|
||||
}
|
||||
|
||||
// Clone creates a clone of the State
|
||||
func (s *ConnectionState) Clone() interface{} {
|
||||
clone := &ConnectionState{
|
||||
ConnectionId: s.ConnectionId,
|
||||
Status: s.Status,
|
||||
}
|
||||
if s.PreviousEndTime != nil {
|
||||
clone.PreviousEndTime = ×tamp.Timestamp{
|
||||
Timestamp: ×tamppb.Timestamp{
|
||||
Seconds: s.PreviousEndTime.Timestamp.Seconds,
|
||||
Nanos: s.PreviousEndTime.Timestamp.Nanos,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if s.StartTime != nil {
|
||||
clone.StartTime = ×tamp.Timestamp{
|
||||
Timestamp: ×tamppb.Timestamp{
|
||||
Seconds: s.StartTime.Timestamp.Seconds,
|
||||
Nanos: s.StartTime.Timestamp.Nanos,
|
||||
},
|
||||
}
|
||||
}
|
||||
if s.EndTime != nil {
|
||||
clone.EndTime = ×tamp.Timestamp{
|
||||
Timestamp: ×tamppb.Timestamp{
|
||||
Seconds: s.EndTime.Timestamp.Seconds,
|
||||
Nanos: s.EndTime.Timestamp.Nanos,
|
||||
},
|
||||
}
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// VetForWrite implements db.VetForWrite() interface and validates the state
|
||||
// before it's written.
|
||||
func (s *ConnectionState) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
|
||||
if err := s.validate("connection state vet for write:"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName returns the tablename to override the default gorm table name
|
||||
func (s *ConnectionState) TableName() string {
|
||||
if s.tableName != "" {
|
||||
return s.tableName
|
||||
}
|
||||
return defaultConnectionStateTableName
|
||||
}
|
||||
|
||||
// SetTableName sets the tablename and satisfies the ReplayableMessage
|
||||
// interface. If the caller attempts to set the name to "" the name will be
|
||||
// reset to the default name.
|
||||
func (s *ConnectionState) SetTableName(n string) {
|
||||
s.tableName = n
|
||||
}
|
||||
|
||||
// validate checks the session state
|
||||
func (s *ConnectionState) validate(errorPrefix string) error {
|
||||
if s.Status == "" {
|
||||
return fmt.Errorf("%s missing status: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if s.ConnectionId == "" {
|
||||
return fmt.Errorf("%s missing connection id: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if s.StartTime != nil {
|
||||
return fmt.Errorf("%s start time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if s.EndTime != nil {
|
||||
return fmt.Errorf("%s end time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if s.PreviousEndTime != nil {
|
||||
return fmt.Errorf("%s previous end time is not settable: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,212 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectionState_Create(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443)
|
||||
|
||||
type args struct {
|
||||
connectionId string
|
||||
status ConnectionStatus
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *ConnectionState
|
||||
wantErr bool
|
||||
wantIsErr error
|
||||
create bool
|
||||
wantCreateErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
connectionId: connection.PublicId,
|
||||
status: StatusClosed,
|
||||
},
|
||||
want: &ConnectionState{
|
||||
ConnectionId: connection.PublicId,
|
||||
Status: StatusClosed.String(),
|
||||
},
|
||||
create: true,
|
||||
},
|
||||
{
|
||||
name: "empty-connectionId",
|
||||
args: args{
|
||||
status: StatusClosed,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-status",
|
||||
args: args{
|
||||
connectionId: connection.PublicId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
got, err := NewConnectionState(tt.args.connectionId, tt.args.status)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.True(errors.Is(err, tt.wantIsErr))
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tt.want, got)
|
||||
if tt.create {
|
||||
err = db.New(conn).Create(context.Background(), got)
|
||||
if tt.wantCreateErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionState_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
state *ConnectionState
|
||||
deleteConnectionStateId string
|
||||
wantRowsDeleted int
|
||||
wantErr bool
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
state: TestConnectionState(t, conn, c.PublicId, StatusClosed),
|
||||
wantErr: false,
|
||||
wantRowsDeleted: 1,
|
||||
},
|
||||
{
|
||||
name: "bad-id",
|
||||
state: TestConnectionState(t, conn, c2.PublicId, StatusClosed),
|
||||
deleteConnectionStateId: func() string {
|
||||
id, err := db.NewPublicId(ConnectionStatePrefix)
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}(),
|
||||
wantErr: false,
|
||||
wantRowsDeleted: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
var initialState ConnectionState
|
||||
err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", tt.state.ConnectionId, tt.state.Status)
|
||||
require.NoError(err)
|
||||
|
||||
deleteState := allocConnectionState()
|
||||
if tt.deleteConnectionStateId != "" {
|
||||
deleteState.ConnectionId = tt.deleteConnectionStateId
|
||||
} else {
|
||||
deleteState.ConnectionId = tt.state.ConnectionId
|
||||
}
|
||||
deleteState.StartTime = initialState.StartTime
|
||||
deletedRows, err := rw.Delete(context.Background(), &deleteState)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
if tt.wantRowsDeleted == 0 {
|
||||
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
||||
return
|
||||
}
|
||||
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
||||
foundState := allocConnectionState()
|
||||
err = rw.LookupWhere(context.Background(), &foundState, "connection_id = ? and start_time = ?", tt.state.ConnectionId, initialState.StartTime)
|
||||
require.Error(err)
|
||||
assert.True(errors.Is(db.ErrRecordNotFound, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionState_Clone(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
|
||||
cp := state.Clone()
|
||||
assert.Equal(cp.(*ConnectionState), state)
|
||||
})
|
||||
t.Run("not-equal", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
|
||||
state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected)
|
||||
|
||||
cp := state.Clone()
|
||||
assert.NotEqual(cp.(*ConnectionState), state2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionState_SetTableName(t *testing.T) {
|
||||
t.Parallel()
|
||||
defaultTableName := defaultConnectionStateTableName
|
||||
tests := []struct {
|
||||
name string
|
||||
setNameTo string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "new-name",
|
||||
setNameTo: "new-name",
|
||||
want: "new-name",
|
||||
},
|
||||
{
|
||||
name: "reset to default",
|
||||
setNameTo: "",
|
||||
want: defaultTableName,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
def := allocConnectionState()
|
||||
require.Equal(defaultTableName, def.TableName())
|
||||
s := allocConnectionState()
|
||||
s.SetTableName(tt.setNameTo)
|
||||
assert.Equal(tt.want, s.TableName())
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue