add Connection type with unit tests

jimlambrt-session-basics
Jim Lambert 6 years ago
parent 9d43b47c65
commit 76eda01d66

@ -0,0 +1,179 @@
package session
import (
"context"
"fmt"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
defaultConnectionTableName = "session_connection"
)
// Session contains information about a user's session with a target
type Connection struct {
// PublicId is used to access the connection via an API
PublicId string `json:"public_id,omitempty" gorm:"primary_key"`
// SessionId of the connection
SessionId string `json:"session_id,omitempty" gorm:"default:null"`
// ClientAddress of the connection
ClientAddress string `json:"client_address,omitempty" gorm:"default:null"`
// ClientPort of the connection
ClientPort uint32 `json:"client_port,omitempty" gorm:"default:null"`
// BackendAddress of the connection
BackendAddress string `json:"backend_address,omitempty" gorm:"default:null"`
// BackendPort of the connection
BackendPort uint32 `json:"backend_port,omitempty" gorm:"default:null"`
// BytesUp of the connection
BytesUp uint64 `json:"bytes_up,omitempty" gorm:"default:null"`
// BytesDown of the connection
BytesDown uint64 `json:"bytes_down,omitempty" gorm:"default:null"`
// ClosedReason of the conneciont
ClosedReason string `json:"closed_reason,omitempty" gorm:"default:null"`
// CreateTime from the RDBMS
CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"`
// UpdateTime from the RDBMS
UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"`
// Version of the connection
Version uint32 `json:"version,omitempty" gorm:"default:null"`
tableName string `gorm:"-"`
}
func (c *Connection) GetPublicId() string {
return c.PublicId
}
var _ Cloneable = (*Connection)(nil)
var _ db.VetForWriter = (*Connection)(nil)
// New creates a new in memory session. No options
// are currently supported.
func NewConnection(sessionID, clientAddress string, clientPort uint32, backendAddr string, backendPort uint32, opt ...Option) (*Connection, error) {
c := Connection{
SessionId: sessionID,
ClientAddress: clientAddress,
ClientPort: clientPort,
BackendAddress: backendAddr,
BackendPort: backendPort,
}
if err := c.validateNewConnection("new connection:"); err != nil {
return nil, err
}
return &c, nil
}
// AllocConnection will allocate a Session
func AllocConnection() Connection {
return Connection{}
}
// Clone creates a clone of the Session
func (c *Connection) Clone() interface{} {
clone := &Connection{
PublicId: c.PublicId,
SessionId: c.SessionId,
ClientAddress: c.ClientAddress,
ClientPort: c.ClientPort,
BackendAddress: c.BackendAddress,
BackendPort: c.BackendPort,
BytesUp: c.BytesUp,
BytesDown: c.BytesDown,
ClosedReason: c.ClosedReason,
Version: c.Version,
}
if c.CreateTime != nil {
clone.CreateTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: c.CreateTime.Timestamp.Seconds,
Nanos: c.CreateTime.Timestamp.Nanos,
},
}
}
if c.UpdateTime != nil {
clone.UpdateTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: c.UpdateTime.Timestamp.Seconds,
Nanos: c.UpdateTime.Timestamp.Nanos,
},
}
}
return clone
}
// VetForWrite implements db.VetForWrite() interface and validates the connection
// before it's written.
func (c *Connection) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
opts := db.GetOpts(opt...)
if c.PublicId == "" {
return fmt.Errorf("connection vet for write: missing public id: %w", db.ErrInvalidParameter)
}
switch opType {
case db.CreateOp:
if err := c.validateNewConnection("connection vet for write:"); err != nil {
return err
}
case db.UpdateOp:
switch {
case contains(opts.WithFieldMaskPaths, "PublicId"):
return fmt.Errorf("connection vet for write: public id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "SessionId"):
return fmt.Errorf("connection vet for write: session id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ClientAddress"):
return fmt.Errorf("connection vet for write: client address is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ClientPort"):
return fmt.Errorf("connection vet for write: client port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "BackendAddress"):
return fmt.Errorf("connection vet for write: backend address is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "BackendPort"):
return fmt.Errorf("connection vet for write: backend port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "CreateTime"):
return fmt.Errorf("connection vet for write: create time is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "UpdateTime"):
return fmt.Errorf("connection vet for write: update time is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ClosedReason"):
if _, err := convertToClosedReason(c.ClosedReason); err != nil {
return fmt.Errorf("connection vet for write: %w", db.ErrInvalidParameter)
}
}
}
return nil
}
// TableName returns the tablename to override the default gorm table name
func (c *Connection) TableName() string {
if c.tableName != "" {
return c.tableName
}
return defaultConnectionTableName
}
// SetTableName sets the tablename and satisfies the ReplayableMessage
// interface. If the caller attempts to set the name to "" the name will be
// reset to the default name.
func (c *Connection) SetTableName(n string) {
c.tableName = n
}
// validateNewConnection checks everything but the connection's PublicId
func (c *Connection) validateNewConnection(errorPrefix string) error {
if c.SessionId == "" {
return fmt.Errorf("%s missing session id: %w", errorPrefix, db.ErrInvalidParameter)
}
if c.ClientAddress == "" {
return fmt.Errorf("%s missing client address: %w", errorPrefix, db.ErrInvalidParameter)
}
if c.ClientPort == 0 {
return fmt.Errorf("%s missing client port: %w", errorPrefix, db.ErrInvalidParameter)
}
if c.BackendAddress == "" {
return fmt.Errorf("%s missing backend address: %w", errorPrefix, db.ErrInvalidParameter)
}
if c.BackendPort == 0 {
return fmt.Errorf("%s missing backend port: %w", errorPrefix, db.ErrInvalidParameter)
}
return nil
}

@ -0,0 +1,255 @@
package session
import (
"context"
"errors"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConnection_Create(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
type args struct {
sessionId string
clientAddress string
clientPort uint32
backendAddress string
backendPort uint32
}
tests := []struct {
name string
args args
want *Connection
wantErr bool
wantIsErr error
create bool
wantCreateErr bool
}{
{
name: "valid",
args: args{
sessionId: s.PublicId,
clientAddress: "127.0.0.1",
clientPort: 22,
backendAddress: "127.0.0.1",
backendPort: 2222,
},
want: &Connection{
SessionId: s.PublicId,
ClientAddress: "127.0.0.1",
ClientPort: 22,
BackendAddress: "127.0.0.1",
BackendPort: 2222,
},
create: true,
},
{
name: "empty-session-id",
args: args{
clientAddress: "127.0.0.1",
clientPort: 22,
backendAddress: "127.0.0.1",
backendPort: 2222,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
{
name: "empty-client-address",
args: args{
sessionId: s.PublicId,
clientPort: 22,
backendAddress: "127.0.0.1",
backendPort: 2222,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
{
name: "empty-client-port",
args: args{
sessionId: s.PublicId,
clientAddress: "localhost",
backendAddress: "127.0.0.1",
backendPort: 2222,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
{
name: "empty-backend-address",
args: args{
sessionId: s.PublicId,
clientAddress: "localhost",
clientPort: 22,
backendPort: 2222,
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
{
name: "empty-backend-port",
args: args{
sessionId: s.PublicId,
clientAddress: "localhost",
clientPort: 22,
backendAddress: "127.0.0.1",
},
wantErr: true,
wantIsErr: db.ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := NewConnection(
tt.args.sessionId,
tt.args.clientAddress,
tt.args.clientPort,
tt.args.backendAddress,
tt.args.backendPort,
)
if tt.wantErr {
require.Error(err)
assert.True(errors.Is(err, tt.wantIsErr))
return
}
require.NoError(err)
assert.Equal(tt.want, got)
if tt.create {
id, err := db.NewPublicId(ConnectionPrefix)
require.NoError(err)
got.PublicId = id
err = db.New(conn).Create(context.Background(), got)
if tt.wantCreateErr {
assert.Error(err)
return
} else {
assert.NoError(err)
}
}
})
}
}
func TestConnection_Delete(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
tests := []struct {
name string
connection *Connection
wantRowsDeleted int
wantErr bool
wantErrMsg string
}{
{
name: "valid",
connection: TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222),
wantErr: false,
wantRowsDeleted: 1,
},
{
name: "bad-id",
connection: func() *Connection {
c := AllocConnection()
id, err := db.NewPublicId(ConnectionPrefix)
require.NoError(t, err)
c.PublicId = id
return &c
}(),
wantErr: false,
wantRowsDeleted: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
deleteConnection := AllocConnection()
deleteConnection.PublicId = tt.connection.PublicId
deletedRows, err := rw.Delete(context.Background(), &deleteConnection)
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
if tt.wantRowsDeleted == 0 {
assert.Equal(tt.wantRowsDeleted, deletedRows)
return
}
assert.Equal(tt.wantRowsDeleted, deletedRows)
foundConnection := AllocConnection()
foundConnection.PublicId = tt.connection.PublicId
err = rw.LookupById(context.Background(), &foundConnection)
require.Error(err)
assert.True(errors.Is(db.ErrRecordNotFound, err))
})
}
}
func TestConnection_Clone(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
t.Run("valid", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
cp := c.Clone()
assert.Equal(cp.(*Connection), c)
})
t.Run("not-equal", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 80, "127.0.0.1", 8080)
cp := c.Clone()
assert.NotEqual(cp.(*Connection), c2)
})
}
func TestConnection_SetTableName(t *testing.T) {
t.Parallel()
defaultTableName := defaultConnectionTableName
tests := []struct {
name string
setNameTo string
want string
}{
{
name: "new-name",
setNameTo: "new-name",
want: "new-name",
},
{
name: "reset to default",
setNameTo: "",
want: defaultTableName,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
def := AllocConnection()
require.Equal(defaultTableName, def.TableName())
c := AllocConnection()
c.SetTableName(tt.setNameTo)
assert.Equal(tt.want, c.TableName())
})
}
}
Loading…
Cancel
Save