From 76eda01d66952ae6ef0986fab6273c30e5cef1ec Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Sat, 12 Sep 2020 19:56:42 -0400 Subject: [PATCH] add Connection type with unit tests --- internal/session/connection.go | 179 +++++++++++++++++++ internal/session/connection_test.go | 255 ++++++++++++++++++++++++++++ 2 files changed, 434 insertions(+) create mode 100644 internal/session/connection.go create mode 100644 internal/session/connection_test.go diff --git a/internal/session/connection.go b/internal/session/connection.go new file mode 100644 index 0000000000..4376ee931a --- /dev/null +++ b/internal/session/connection.go @@ -0,0 +1,179 @@ +package session + +import ( + "context" + "fmt" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + defaultConnectionTableName = "session_connection" +) + +// Session contains information about a user's session with a target +type Connection struct { + // PublicId is used to access the connection via an API + PublicId string `json:"public_id,omitempty" gorm:"primary_key"` + // SessionId of the connection + SessionId string `json:"session_id,omitempty" gorm:"default:null"` + // ClientAddress of the connection + ClientAddress string `json:"client_address,omitempty" gorm:"default:null"` + // ClientPort of the connection + ClientPort uint32 `json:"client_port,omitempty" gorm:"default:null"` + // BackendAddress of the connection + BackendAddress string `json:"backend_address,omitempty" gorm:"default:null"` + // BackendPort of the connection + BackendPort uint32 `json:"backend_port,omitempty" gorm:"default:null"` + // BytesUp of the connection + BytesUp uint64 `json:"bytes_up,omitempty" gorm:"default:null"` + // BytesDown of the connection + BytesDown uint64 `json:"bytes_down,omitempty" gorm:"default:null"` + // ClosedReason of the conneciont + ClosedReason string `json:"closed_reason,omitempty" gorm:"default:null"` + // CreateTime from the RDBMS + CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"` + // UpdateTime from the RDBMS + UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"` + // Version of the connection + Version uint32 `json:"version,omitempty" gorm:"default:null"` + + tableName string `gorm:"-"` +} + +func (c *Connection) GetPublicId() string { + return c.PublicId +} + +var _ Cloneable = (*Connection)(nil) +var _ db.VetForWriter = (*Connection)(nil) + +// New creates a new in memory session. No options +// are currently supported. +func NewConnection(sessionID, clientAddress string, clientPort uint32, backendAddr string, backendPort uint32, opt ...Option) (*Connection, error) { + c := Connection{ + SessionId: sessionID, + ClientAddress: clientAddress, + ClientPort: clientPort, + BackendAddress: backendAddr, + BackendPort: backendPort, + } + if err := c.validateNewConnection("new connection:"); err != nil { + return nil, err + } + return &c, nil +} + +// AllocConnection will allocate a Session +func AllocConnection() Connection { + return Connection{} +} + +// Clone creates a clone of the Session +func (c *Connection) Clone() interface{} { + clone := &Connection{ + PublicId: c.PublicId, + SessionId: c.SessionId, + ClientAddress: c.ClientAddress, + ClientPort: c.ClientPort, + BackendAddress: c.BackendAddress, + BackendPort: c.BackendPort, + BytesUp: c.BytesUp, + BytesDown: c.BytesDown, + ClosedReason: c.ClosedReason, + Version: c.Version, + } + if c.CreateTime != nil { + clone.CreateTime = ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: c.CreateTime.Timestamp.Seconds, + Nanos: c.CreateTime.Timestamp.Nanos, + }, + } + } + if c.UpdateTime != nil { + clone.UpdateTime = ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: c.UpdateTime.Timestamp.Seconds, + Nanos: c.UpdateTime.Timestamp.Nanos, + }, + } + } + return clone +} + +// VetForWrite implements db.VetForWrite() interface and validates the connection +// before it's written. +func (c *Connection) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error { + opts := db.GetOpts(opt...) + if c.PublicId == "" { + return fmt.Errorf("connection vet for write: missing public id: %w", db.ErrInvalidParameter) + } + switch opType { + case db.CreateOp: + if err := c.validateNewConnection("connection vet for write:"); err != nil { + return err + } + case db.UpdateOp: + switch { + case contains(opts.WithFieldMaskPaths, "PublicId"): + return fmt.Errorf("connection vet for write: public id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "SessionId"): + return fmt.Errorf("connection vet for write: session id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "ClientAddress"): + return fmt.Errorf("connection vet for write: client address is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "ClientPort"): + return fmt.Errorf("connection vet for write: client port is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "BackendAddress"): + return fmt.Errorf("connection vet for write: backend address is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "BackendPort"): + return fmt.Errorf("connection vet for write: backend port is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "CreateTime"): + return fmt.Errorf("connection vet for write: create time is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "UpdateTime"): + return fmt.Errorf("connection vet for write: update time is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "ClosedReason"): + if _, err := convertToClosedReason(c.ClosedReason); err != nil { + return fmt.Errorf("connection vet for write: %w", db.ErrInvalidParameter) + } + } + } + return nil +} + +// TableName returns the tablename to override the default gorm table name +func (c *Connection) TableName() string { + if c.tableName != "" { + return c.tableName + } + return defaultConnectionTableName +} + +// SetTableName sets the tablename and satisfies the ReplayableMessage +// interface. If the caller attempts to set the name to "" the name will be +// reset to the default name. +func (c *Connection) SetTableName(n string) { + c.tableName = n +} + +// validateNewConnection checks everything but the connection's PublicId +func (c *Connection) validateNewConnection(errorPrefix string) error { + if c.SessionId == "" { + return fmt.Errorf("%s missing session id: %w", errorPrefix, db.ErrInvalidParameter) + } + if c.ClientAddress == "" { + return fmt.Errorf("%s missing client address: %w", errorPrefix, db.ErrInvalidParameter) + } + if c.ClientPort == 0 { + return fmt.Errorf("%s missing client port: %w", errorPrefix, db.ErrInvalidParameter) + } + if c.BackendAddress == "" { + return fmt.Errorf("%s missing backend address: %w", errorPrefix, db.ErrInvalidParameter) + } + if c.BackendPort == 0 { + return fmt.Errorf("%s missing backend port: %w", errorPrefix, db.ErrInvalidParameter) + } + return nil +} diff --git a/internal/session/connection_test.go b/internal/session/connection_test.go new file mode 100644 index 0000000000..fea3963841 --- /dev/null +++ b/internal/session/connection_test.go @@ -0,0 +1,255 @@ +package session + +import ( + "context" + "errors" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/iam" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnection_Create(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + + type args struct { + sessionId string + clientAddress string + clientPort uint32 + backendAddress string + backendPort uint32 + } + tests := []struct { + name string + args args + want *Connection + wantErr bool + wantIsErr error + create bool + wantCreateErr bool + }{ + { + name: "valid", + args: args{ + sessionId: s.PublicId, + clientAddress: "127.0.0.1", + clientPort: 22, + backendAddress: "127.0.0.1", + backendPort: 2222, + }, + want: &Connection{ + SessionId: s.PublicId, + ClientAddress: "127.0.0.1", + ClientPort: 22, + BackendAddress: "127.0.0.1", + BackendPort: 2222, + }, + create: true, + }, + { + name: "empty-session-id", + args: args{ + clientAddress: "127.0.0.1", + clientPort: 22, + backendAddress: "127.0.0.1", + backendPort: 2222, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "empty-client-address", + args: args{ + sessionId: s.PublicId, + clientPort: 22, + backendAddress: "127.0.0.1", + backendPort: 2222, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "empty-client-port", + args: args{ + sessionId: s.PublicId, + clientAddress: "localhost", + backendAddress: "127.0.0.1", + backendPort: 2222, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "empty-backend-address", + args: args{ + sessionId: s.PublicId, + clientAddress: "localhost", + clientPort: 22, + backendPort: 2222, + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + { + name: "empty-backend-port", + args: args{ + sessionId: s.PublicId, + clientAddress: "localhost", + clientPort: 22, + backendAddress: "127.0.0.1", + }, + wantErr: true, + wantIsErr: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + got, err := NewConnection( + tt.args.sessionId, + tt.args.clientAddress, + tt.args.clientPort, + tt.args.backendAddress, + tt.args.backendPort, + ) + if tt.wantErr { + require.Error(err) + assert.True(errors.Is(err, tt.wantIsErr)) + return + } + require.NoError(err) + assert.Equal(tt.want, got) + if tt.create { + id, err := db.NewPublicId(ConnectionPrefix) + require.NoError(err) + got.PublicId = id + err = db.New(conn).Create(context.Background(), got) + if tt.wantCreateErr { + assert.Error(err) + return + } else { + assert.NoError(err) + } + } + }) + } +} + +func TestConnection_Delete(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + + tests := []struct { + name string + connection *Connection + wantRowsDeleted int + wantErr bool + wantErrMsg string + }{ + { + name: "valid", + connection: TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222), + wantErr: false, + wantRowsDeleted: 1, + }, + { + name: "bad-id", + connection: func() *Connection { + c := AllocConnection() + id, err := db.NewPublicId(ConnectionPrefix) + require.NoError(t, err) + c.PublicId = id + return &c + }(), + wantErr: false, + wantRowsDeleted: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + deleteConnection := AllocConnection() + deleteConnection.PublicId = tt.connection.PublicId + deletedRows, err := rw.Delete(context.Background(), &deleteConnection) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + if tt.wantRowsDeleted == 0 { + assert.Equal(tt.wantRowsDeleted, deletedRows) + return + } + assert.Equal(tt.wantRowsDeleted, deletedRows) + foundConnection := AllocConnection() + foundConnection.PublicId = tt.connection.PublicId + err = rw.LookupById(context.Background(), &foundConnection) + require.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + }) + } +} + +func TestConnection_Clone(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + t.Run("valid", func(t *testing.T) { + assert := assert.New(t) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + cp := c.Clone() + assert.Equal(cp.(*Connection), c) + }) + t.Run("not-equal", func(t *testing.T) { + assert := assert.New(t) + s := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 80, "127.0.0.1", 8080) + + cp := c.Clone() + assert.NotEqual(cp.(*Connection), c2) + }) +} + +func TestConnection_SetTableName(t *testing.T) { + t.Parallel() + defaultTableName := defaultConnectionTableName + tests := []struct { + name string + setNameTo string + want string + }{ + { + name: "new-name", + setNameTo: "new-name", + want: "new-name", + }, + { + name: "reset to default", + setNameTo: "", + want: defaultTableName, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + def := AllocConnection() + require.Equal(defaultTableName, def.TableName()) + c := AllocConnection() + c.SetTableName(tt.setNameTo) + assert.Equal(tt.want, c.TableName()) + }) + } +}