mirror of https://github.com/hashicorp/boundary
parent
9d43b47c65
commit
76eda01d66
@ -0,0 +1,179 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/db/timestamp"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConnectionTableName = "session_connection"
|
||||
)
|
||||
|
||||
// Session contains information about a user's session with a target
|
||||
type Connection struct {
|
||||
// PublicId is used to access the connection via an API
|
||||
PublicId string `json:"public_id,omitempty" gorm:"primary_key"`
|
||||
// SessionId of the connection
|
||||
SessionId string `json:"session_id,omitempty" gorm:"default:null"`
|
||||
// ClientAddress of the connection
|
||||
ClientAddress string `json:"client_address,omitempty" gorm:"default:null"`
|
||||
// ClientPort of the connection
|
||||
ClientPort uint32 `json:"client_port,omitempty" gorm:"default:null"`
|
||||
// BackendAddress of the connection
|
||||
BackendAddress string `json:"backend_address,omitempty" gorm:"default:null"`
|
||||
// BackendPort of the connection
|
||||
BackendPort uint32 `json:"backend_port,omitempty" gorm:"default:null"`
|
||||
// BytesUp of the connection
|
||||
BytesUp uint64 `json:"bytes_up,omitempty" gorm:"default:null"`
|
||||
// BytesDown of the connection
|
||||
BytesDown uint64 `json:"bytes_down,omitempty" gorm:"default:null"`
|
||||
// ClosedReason of the conneciont
|
||||
ClosedReason string `json:"closed_reason,omitempty" gorm:"default:null"`
|
||||
// CreateTime from the RDBMS
|
||||
CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"`
|
||||
// UpdateTime from the RDBMS
|
||||
UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"`
|
||||
// Version of the connection
|
||||
Version uint32 `json:"version,omitempty" gorm:"default:null"`
|
||||
|
||||
tableName string `gorm:"-"`
|
||||
}
|
||||
|
||||
func (c *Connection) GetPublicId() string {
|
||||
return c.PublicId
|
||||
}
|
||||
|
||||
var _ Cloneable = (*Connection)(nil)
|
||||
var _ db.VetForWriter = (*Connection)(nil)
|
||||
|
||||
// New creates a new in memory session. No options
|
||||
// are currently supported.
|
||||
func NewConnection(sessionID, clientAddress string, clientPort uint32, backendAddr string, backendPort uint32, opt ...Option) (*Connection, error) {
|
||||
c := Connection{
|
||||
SessionId: sessionID,
|
||||
ClientAddress: clientAddress,
|
||||
ClientPort: clientPort,
|
||||
BackendAddress: backendAddr,
|
||||
BackendPort: backendPort,
|
||||
}
|
||||
if err := c.validateNewConnection("new connection:"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// AllocConnection will allocate a Session
|
||||
func AllocConnection() Connection {
|
||||
return Connection{}
|
||||
}
|
||||
|
||||
// Clone creates a clone of the Session
|
||||
func (c *Connection) Clone() interface{} {
|
||||
clone := &Connection{
|
||||
PublicId: c.PublicId,
|
||||
SessionId: c.SessionId,
|
||||
ClientAddress: c.ClientAddress,
|
||||
ClientPort: c.ClientPort,
|
||||
BackendAddress: c.BackendAddress,
|
||||
BackendPort: c.BackendPort,
|
||||
BytesUp: c.BytesUp,
|
||||
BytesDown: c.BytesDown,
|
||||
ClosedReason: c.ClosedReason,
|
||||
Version: c.Version,
|
||||
}
|
||||
if c.CreateTime != nil {
|
||||
clone.CreateTime = ×tamp.Timestamp{
|
||||
Timestamp: ×tamppb.Timestamp{
|
||||
Seconds: c.CreateTime.Timestamp.Seconds,
|
||||
Nanos: c.CreateTime.Timestamp.Nanos,
|
||||
},
|
||||
}
|
||||
}
|
||||
if c.UpdateTime != nil {
|
||||
clone.UpdateTime = ×tamp.Timestamp{
|
||||
Timestamp: ×tamppb.Timestamp{
|
||||
Seconds: c.UpdateTime.Timestamp.Seconds,
|
||||
Nanos: c.UpdateTime.Timestamp.Nanos,
|
||||
},
|
||||
}
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// VetForWrite implements db.VetForWrite() interface and validates the connection
|
||||
// before it's written.
|
||||
func (c *Connection) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
|
||||
opts := db.GetOpts(opt...)
|
||||
if c.PublicId == "" {
|
||||
return fmt.Errorf("connection vet for write: missing public id: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
switch opType {
|
||||
case db.CreateOp:
|
||||
if err := c.validateNewConnection("connection vet for write:"); err != nil {
|
||||
return err
|
||||
}
|
||||
case db.UpdateOp:
|
||||
switch {
|
||||
case contains(opts.WithFieldMaskPaths, "PublicId"):
|
||||
return fmt.Errorf("connection vet for write: public id is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "SessionId"):
|
||||
return fmt.Errorf("connection vet for write: session id is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "ClientAddress"):
|
||||
return fmt.Errorf("connection vet for write: client address is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "ClientPort"):
|
||||
return fmt.Errorf("connection vet for write: client port is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "BackendAddress"):
|
||||
return fmt.Errorf("connection vet for write: backend address is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "BackendPort"):
|
||||
return fmt.Errorf("connection vet for write: backend port is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "CreateTime"):
|
||||
return fmt.Errorf("connection vet for write: create time is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "UpdateTime"):
|
||||
return fmt.Errorf("connection vet for write: update time is immutable: %w", db.ErrInvalidParameter)
|
||||
case contains(opts.WithFieldMaskPaths, "ClosedReason"):
|
||||
if _, err := convertToClosedReason(c.ClosedReason); err != nil {
|
||||
return fmt.Errorf("connection vet for write: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName returns the tablename to override the default gorm table name
|
||||
func (c *Connection) TableName() string {
|
||||
if c.tableName != "" {
|
||||
return c.tableName
|
||||
}
|
||||
return defaultConnectionTableName
|
||||
}
|
||||
|
||||
// SetTableName sets the tablename and satisfies the ReplayableMessage
|
||||
// interface. If the caller attempts to set the name to "" the name will be
|
||||
// reset to the default name.
|
||||
func (c *Connection) SetTableName(n string) {
|
||||
c.tableName = n
|
||||
}
|
||||
|
||||
// validateNewConnection checks everything but the connection's PublicId
|
||||
func (c *Connection) validateNewConnection(errorPrefix string) error {
|
||||
if c.SessionId == "" {
|
||||
return fmt.Errorf("%s missing session id: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if c.ClientAddress == "" {
|
||||
return fmt.Errorf("%s missing client address: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if c.ClientPort == 0 {
|
||||
return fmt.Errorf("%s missing client port: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if c.BackendAddress == "" {
|
||||
return fmt.Errorf("%s missing backend address: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
if c.BackendPort == 0 {
|
||||
return fmt.Errorf("%s missing backend port: %w", errorPrefix, db.ErrInvalidParameter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,255 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnection_Create(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
type args struct {
|
||||
sessionId string
|
||||
clientAddress string
|
||||
clientPort uint32
|
||||
backendAddress string
|
||||
backendPort uint32
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Connection
|
||||
wantErr bool
|
||||
wantIsErr error
|
||||
create bool
|
||||
wantCreateErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
clientAddress: "127.0.0.1",
|
||||
clientPort: 22,
|
||||
backendAddress: "127.0.0.1",
|
||||
backendPort: 2222,
|
||||
},
|
||||
want: &Connection{
|
||||
SessionId: s.PublicId,
|
||||
ClientAddress: "127.0.0.1",
|
||||
ClientPort: 22,
|
||||
BackendAddress: "127.0.0.1",
|
||||
BackendPort: 2222,
|
||||
},
|
||||
create: true,
|
||||
},
|
||||
{
|
||||
name: "empty-session-id",
|
||||
args: args{
|
||||
clientAddress: "127.0.0.1",
|
||||
clientPort: 22,
|
||||
backendAddress: "127.0.0.1",
|
||||
backendPort: 2222,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-client-address",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
clientPort: 22,
|
||||
backendAddress: "127.0.0.1",
|
||||
backendPort: 2222,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-client-port",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
clientAddress: "localhost",
|
||||
backendAddress: "127.0.0.1",
|
||||
backendPort: 2222,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-backend-address",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
clientAddress: "localhost",
|
||||
clientPort: 22,
|
||||
backendPort: 2222,
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-backend-port",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
clientAddress: "localhost",
|
||||
clientPort: 22,
|
||||
backendAddress: "127.0.0.1",
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsErr: db.ErrInvalidParameter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
got, err := NewConnection(
|
||||
tt.args.sessionId,
|
||||
tt.args.clientAddress,
|
||||
tt.args.clientPort,
|
||||
tt.args.backendAddress,
|
||||
tt.args.backendPort,
|
||||
)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.True(errors.Is(err, tt.wantIsErr))
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tt.want, got)
|
||||
if tt.create {
|
||||
id, err := db.NewPublicId(ConnectionPrefix)
|
||||
require.NoError(err)
|
||||
got.PublicId = id
|
||||
err = db.New(conn).Create(context.Background(), got)
|
||||
if tt.wantCreateErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnection_Delete(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connection *Connection
|
||||
wantRowsDeleted int
|
||||
wantErr bool
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
connection: TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222),
|
||||
wantErr: false,
|
||||
wantRowsDeleted: 1,
|
||||
},
|
||||
{
|
||||
name: "bad-id",
|
||||
connection: func() *Connection {
|
||||
c := AllocConnection()
|
||||
id, err := db.NewPublicId(ConnectionPrefix)
|
||||
require.NoError(t, err)
|
||||
c.PublicId = id
|
||||
return &c
|
||||
}(),
|
||||
wantErr: false,
|
||||
wantRowsDeleted: 0,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
deleteConnection := AllocConnection()
|
||||
deleteConnection.PublicId = tt.connection.PublicId
|
||||
deletedRows, err := rw.Delete(context.Background(), &deleteConnection)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
if tt.wantRowsDeleted == 0 {
|
||||
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
||||
return
|
||||
}
|
||||
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
||||
foundConnection := AllocConnection()
|
||||
foundConnection.PublicId = tt.connection.PublicId
|
||||
err = rw.LookupById(context.Background(), &foundConnection)
|
||||
require.Error(err)
|
||||
assert.True(errors.Is(db.ErrRecordNotFound, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnection_Clone(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
cp := c.Clone()
|
||||
assert.Equal(cp.(*Connection), c)
|
||||
})
|
||||
t.Run("not-equal", func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
s := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
|
||||
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 80, "127.0.0.1", 8080)
|
||||
|
||||
cp := c.Clone()
|
||||
assert.NotEqual(cp.(*Connection), c2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnection_SetTableName(t *testing.T) {
|
||||
t.Parallel()
|
||||
defaultTableName := defaultConnectionTableName
|
||||
tests := []struct {
|
||||
name string
|
||||
setNameTo string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "new-name",
|
||||
setNameTo: "new-name",
|
||||
want: "new-name",
|
||||
},
|
||||
{
|
||||
name: "reset to default",
|
||||
setNameTo: "",
|
||||
want: defaultTableName,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
def := AllocConnection()
|
||||
require.Equal(defaultTableName, def.TableName())
|
||||
c := AllocConnection()
|
||||
c.SetTableName(tt.setNameTo)
|
||||
assert.Equal(tt.want, c.TableName())
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue