mirror of https://github.com/hashicorp/boundary
parent
32e85ee6f0
commit
10d47817b9
@ -0,0 +1,297 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
dbcommon "github.com/hashicorp/boundary/internal/db/common"
|
||||
)
|
||||
|
||||
// CreateConnection inserts into the repository and returns the new Connection with
|
||||
// its State of "Pending". The following fields must be empty when creating a
|
||||
// session: PublicId, BytesUp, BytesDown, ClosedReason, Version, CreateTime,
|
||||
// UpdateTime. No options are currently supported.
|
||||
func (r *Repository) CreateConnection(ctx context.Context, newConnection *Connection, opt ...Option) (*Connection, *ConnectionState, error) {
|
||||
if newConnection == nil {
|
||||
return nil, nil, fmt.Errorf("create connection: missing connection: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.PublicId != "" {
|
||||
return nil, nil, fmt.Errorf("create connection: public id is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.BytesUp != 0 {
|
||||
return nil, nil, fmt.Errorf("create connection: bytes down is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.BytesDown != 0 {
|
||||
return nil, nil, fmt.Errorf("create connection: bytes up is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.ClosedReason != "" {
|
||||
return nil, nil, fmt.Errorf("create connection: closed reason is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.Version != 0 {
|
||||
return nil, nil, fmt.Errorf("create connection: version is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.CreateTime != nil {
|
||||
return nil, nil, fmt.Errorf("create connection: create time is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if newConnection.UpdateTime != nil {
|
||||
return nil, nil, fmt.Errorf("create connection: update time is not empty: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
|
||||
id, err := newConnectionId()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create connection: %w", err)
|
||||
}
|
||||
newConnection.PublicId = id
|
||||
|
||||
var returnedConnection *Connection
|
||||
var returnedState *ConnectionState
|
||||
_, err = r.writer.DoTx(
|
||||
ctx,
|
||||
db.StdRetryCnt,
|
||||
db.ExpBackoff{},
|
||||
func(read db.Reader, w db.Writer) error {
|
||||
returnedConnection = newConnection.Clone().(*Connection)
|
||||
if err = w.Create(ctx, returnedConnection); err != nil {
|
||||
return err
|
||||
}
|
||||
var foundStates []*ConnectionState
|
||||
// trigger will create new "Pending" state
|
||||
if foundStates, err = fetchConnectionStates(ctx, read, returnedConnection.PublicId); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(foundStates) != 1 {
|
||||
return fmt.Errorf("%d states found for new connection %s", len(foundStates), returnedConnection.PublicId)
|
||||
}
|
||||
returnedState = foundStates[0]
|
||||
if returnedState.Status != StatusConnected.String() {
|
||||
return fmt.Errorf("new connection %s state is not valid: %s", returnedConnection.PublicId, returnedState.Status)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create connection: %w", err)
|
||||
}
|
||||
return returnedConnection, returnedState, err
|
||||
}
|
||||
|
||||
// LookupConnection will look up a connection in the repository and return the connection
|
||||
// with its states. If the connection is not found, it will return nil, nil, nil.
|
||||
// No options are currently supported.
|
||||
func (r *Repository) LookupConnection(ctx context.Context, connectionId string, opt ...Option) (*Connection, []*ConnectionState, error) {
|
||||
if connectionId == "" {
|
||||
return nil, nil, fmt.Errorf("lookup connection: missing connectionId id: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
connection := AllocConnection()
|
||||
connection.PublicId = connectionId
|
||||
var states []*ConnectionState
|
||||
_, err := r.writer.DoTx(
|
||||
ctx,
|
||||
db.StdRetryCnt,
|
||||
db.ExpBackoff{},
|
||||
func(read db.Reader, w db.Writer) error {
|
||||
if err := read.LookupById(ctx, &connection); err != nil {
|
||||
return fmt.Errorf("lookup connection: failed %w for %s", err, connectionId)
|
||||
}
|
||||
var err error
|
||||
if states, err = fetchConnectionStates(ctx, read, connectionId, db.WithOrder("start_time desc")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrRecordNotFound) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return nil, nil, fmt.Errorf("lookup connection: %w", err)
|
||||
}
|
||||
return &connection, states, nil
|
||||
}
|
||||
|
||||
// ListConnections will sessions. Supports the WithLimit and WithOrder options.
|
||||
func (r *Repository) ListConnections(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) {
|
||||
var connections []*Connection
|
||||
err := r.list(ctx, &connections, "session_id = ?", []interface{}{sessionId}, opt...) // pass options, so WithLimit and WithOrder are supported
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list connections: %w", err)
|
||||
}
|
||||
return connections, nil
|
||||
}
|
||||
|
||||
// DeleteConnection will delete a connection from the repository.
|
||||
func (r *Repository) DeleteConnection(ctx context.Context, publicId string, opt ...Option) (int, error) {
|
||||
if publicId == "" {
|
||||
return db.NoRowsAffected, fmt.Errorf("delete connection: missing public id %w", db.ErrInvalidParameter)
|
||||
}
|
||||
connection := AllocConnection()
|
||||
connection.PublicId = publicId
|
||||
if err := r.reader.LookupByPublicId(ctx, &connection); err != nil {
|
||||
return db.NoRowsAffected, fmt.Errorf("delete connection: failed %w for %s", err, publicId)
|
||||
}
|
||||
|
||||
var rowsDeleted int
|
||||
_, err := r.writer.DoTx(
|
||||
ctx,
|
||||
db.StdRetryCnt,
|
||||
db.ExpBackoff{},
|
||||
func(_ db.Reader, w db.Writer) error {
|
||||
deleteSession := connection.Clone()
|
||||
var err error
|
||||
rowsDeleted, err = w.Delete(
|
||||
ctx,
|
||||
deleteSession,
|
||||
)
|
||||
if err == nil && rowsDeleted > 1 {
|
||||
// return err, which will result in a rollback of the delete
|
||||
return errors.New("error more than 1 connection would have been deleted")
|
||||
}
|
||||
return err
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return db.NoRowsAffected, fmt.Errorf("delete connection: failed %w for %s", err, publicId)
|
||||
}
|
||||
return rowsDeleted, nil
|
||||
}
|
||||
|
||||
// UpdateConnection updates the repository entry for the connection, using the
|
||||
// fieldMaskPaths. Only BytesUp, BytesDown, and ClosedReason are muttable and
|
||||
// will be set to NULL if set to a zero value and included in the fieldMaskPaths.
|
||||
func (r *Repository) UpdateConnection(ctx context.Context, connection *Connection, version uint32, fieldMaskPaths []string, opt ...Option) (*Connection, []*ConnectionState, int, error) {
|
||||
if connection == nil {
|
||||
return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: missing connection %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if connection.PublicId == "" {
|
||||
return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: missing connection public id %w", db.ErrInvalidParameter)
|
||||
}
|
||||
for _, f := range fieldMaskPaths {
|
||||
switch {
|
||||
case strings.EqualFold("BytesUp", f):
|
||||
case strings.EqualFold("BytesDown", f):
|
||||
case strings.EqualFold("ClosedReason", f):
|
||||
default:
|
||||
return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: field: %s: %w", f, db.ErrInvalidFieldMask)
|
||||
}
|
||||
}
|
||||
var dbMask, nullFields []string
|
||||
dbMask, nullFields = dbcommon.BuildUpdatePaths(
|
||||
map[string]interface{}{
|
||||
"BytesUp": connection.BytesUp,
|
||||
"BytesDown": connection.BytesDown,
|
||||
"ClosedReason": connection.ClosedReason,
|
||||
},
|
||||
fieldMaskPaths,
|
||||
)
|
||||
if len(dbMask) == 0 && len(nullFields) == 0 {
|
||||
return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: %w", db.ErrEmptyFieldMask)
|
||||
}
|
||||
|
||||
var c *Connection
|
||||
var states []*ConnectionState
|
||||
var rowsUpdated int
|
||||
_, err := r.writer.DoTx(
|
||||
ctx,
|
||||
db.StdRetryCnt,
|
||||
db.ExpBackoff{},
|
||||
func(reader db.Reader, w db.Writer) error {
|
||||
var err error
|
||||
c = connection.Clone().(*Connection)
|
||||
rowsUpdated, err = w.Update(
|
||||
ctx,
|
||||
c,
|
||||
dbMask,
|
||||
nullFields,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err == nil && rowsUpdated > 1 {
|
||||
// return err, which will result in a rollback of the update
|
||||
return errors.New("error more than 1 connection would have been updated ")
|
||||
}
|
||||
states, err = fetchConnectionStates(ctx, reader, c.PublicId, db.WithOrder("start_time desc"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: %w for %s", err, connection.PublicId)
|
||||
}
|
||||
return c, states, rowsUpdated, err
|
||||
}
|
||||
|
||||
// UpdateConnectionState will update the connection's state using the connection id and its
|
||||
// version. No options are currently supported.
|
||||
func (r *Repository) UpdateConnectionState(ctx context.Context, connectionId string, connectionVersion uint32, s ConnectionStatus, opt ...Option) (*Connection, []*ConnectionState, error) {
|
||||
if connectionId == "" {
|
||||
return nil, nil, fmt.Errorf("update connection state: missing session id %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if connectionVersion == 0 {
|
||||
return nil, nil, fmt.Errorf("update connection state: version cannot be zero: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
if s == "" {
|
||||
return nil, nil, fmt.Errorf("update connection state: missing connection status: %w", db.ErrInvalidParameter)
|
||||
}
|
||||
|
||||
newState, err := NewConnectionState(connectionId, s)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("update connection state: %w", err)
|
||||
}
|
||||
sessionConnection, _, err := r.LookupConnection(ctx, connectionId)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("update connection state: %w", err)
|
||||
}
|
||||
if sessionConnection == nil {
|
||||
return nil, nil, fmt.Errorf("update connection state: unable to look up connection for %s: %w", connectionId, err)
|
||||
}
|
||||
|
||||
updatedConnection := AllocConnection()
|
||||
var returnedStates []*ConnectionState
|
||||
_, err = r.writer.DoTx(
|
||||
ctx,
|
||||
db.StdRetryCnt,
|
||||
db.ExpBackoff{},
|
||||
func(reader db.Reader, w db.Writer) error {
|
||||
// We need to update the session version as that's the aggregate
|
||||
updatedConnection.PublicId = connectionId
|
||||
updatedConnection.Version = uint32(connectionVersion) + 1
|
||||
rowsUpdated, err := w.Update(ctx, &updatedConnection, []string{"Version"}, nil, db.WithVersion(&connectionVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update connection version: %w", err)
|
||||
}
|
||||
if rowsUpdated != 1 {
|
||||
return fmt.Errorf("updated connection and %d rows updated", rowsUpdated)
|
||||
}
|
||||
if err := w.Create(ctx, newState); err != nil {
|
||||
return fmt.Errorf("unable to add new state: %w", err)
|
||||
}
|
||||
|
||||
returnedStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("update connection state: error creating new state: %w", err)
|
||||
}
|
||||
return &updatedConnection, returnedStates, nil
|
||||
}
|
||||
|
||||
func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) {
|
||||
var states []*ConnectionState
|
||||
if err := r.SearchWhere(ctx, &states, "connection_id = ?", []interface{}{connectionId}, opt...); err != nil {
|
||||
return nil, fmt.Errorf("fetch connection states: %w", err)
|
||||
}
|
||||
if len(states) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return states, nil
|
||||
}
|
||||
@ -0,0 +1,631 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
dbassert "github.com/hashicorp/boundary/internal/db/assert"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/hashicorp/boundary/internal/oplog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRepository_ListConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
const testLimit = 10
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
rw := db.New(conn)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit))
|
||||
require.NoError(t, err)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
type args struct {
|
||||
searchForSessionId string
|
||||
opt []Option
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
createCnt int
|
||||
args args
|
||||
wantCnt int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no-limit",
|
||||
createCnt: repo.defaultLimit + 1,
|
||||
args: args{
|
||||
searchForSessionId: session.PublicId,
|
||||
opt: []Option{WithLimit(-1)},
|
||||
},
|
||||
wantCnt: repo.defaultLimit + 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "default-limit",
|
||||
createCnt: repo.defaultLimit + 1,
|
||||
args: args{
|
||||
searchForSessionId: session.PublicId,
|
||||
},
|
||||
wantCnt: repo.defaultLimit,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "custom-limit",
|
||||
createCnt: repo.defaultLimit + 1,
|
||||
args: args{
|
||||
searchForSessionId: session.PublicId,
|
||||
opt: []Option{WithLimit(3)},
|
||||
},
|
||||
wantCnt: 3,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-session-id",
|
||||
createCnt: repo.defaultLimit + 1,
|
||||
args: args{
|
||||
searchForSessionId: "s_thisIsNotValid",
|
||||
},
|
||||
wantCnt: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
require.NoError(conn.Where("1=1").Delete(AllocConnection()).Error)
|
||||
testConnections := []*Connection{}
|
||||
for i := 0; i < tt.createCnt; i++ {
|
||||
c := TestConnection(t, conn,
|
||||
session.PublicId,
|
||||
"127.0.0.1",
|
||||
22,
|
||||
"127.0.0.1",
|
||||
2222,
|
||||
)
|
||||
testConnections = append(testConnections, c)
|
||||
}
|
||||
assert.Equal(tt.createCnt, len(testConnections))
|
||||
got, err := repo.ListConnections(context.Background(), tt.args.searchForSessionId, tt.args.opt...)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tt.wantCnt, len(got))
|
||||
})
|
||||
}
|
||||
t.Run("withOrder", func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
require.NoError(conn.Where("1=1").Delete(AllocConnection()).Error)
|
||||
wantCnt := 5
|
||||
for i := 0; i < wantCnt; i++ {
|
||||
_ = TestConnection(t, conn,
|
||||
session.PublicId,
|
||||
"127.0.0.1",
|
||||
22,
|
||||
"127.0.0.1",
|
||||
2222,
|
||||
)
|
||||
}
|
||||
got, err := repo.ListConnections(context.Background(), session.PublicId, WithOrder("create_time asc"))
|
||||
require.NoError(err)
|
||||
assert.Equal(wantCnt, len(got))
|
||||
|
||||
for i := 0; i < len(got)-1; i++ {
|
||||
first, err := ptypes.Timestamp(got[i].CreateTime.Timestamp)
|
||||
require.NoError(err)
|
||||
second, err := ptypes.Timestamp(got[i+1].CreateTime.Timestamp)
|
||||
require.NoError(err)
|
||||
assert.True(first.Before(second))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepository_CreateConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
type args struct {
|
||||
connection *Connection
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
wantIsError error
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
connection: func() *Connection {
|
||||
c, err := NewConnection(
|
||||
session.PublicId,
|
||||
"127.0.0.1",
|
||||
22,
|
||||
"127.0.0.1",
|
||||
2222,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return c
|
||||
}(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty-session-id",
|
||||
args: args{
|
||||
connection: &Connection{
|
||||
ClientAddress: "127.0.0.1",
|
||||
ClientPort: 22,
|
||||
BackendAddress: "127.0.0.1",
|
||||
BackendPort: 2222,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-client-address",
|
||||
args: args{
|
||||
connection: &Connection{
|
||||
SessionId: session.PublicId,
|
||||
ClientPort: 22,
|
||||
BackendAddress: "127.0.0.1",
|
||||
BackendPort: 2222,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-client-port",
|
||||
args: args{
|
||||
connection: &Connection{
|
||||
SessionId: session.PublicId,
|
||||
ClientAddress: "127.0.0.1",
|
||||
BackendAddress: "127.0.0.1",
|
||||
BackendPort: 2222,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-backend-address",
|
||||
args: args{
|
||||
connection: &Connection{
|
||||
SessionId: session.PublicId,
|
||||
ClientAddress: "127.0.0.1",
|
||||
ClientPort: 22,
|
||||
BackendPort: 2222,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "empty-backend-port",
|
||||
args: args{
|
||||
connection: &Connection{
|
||||
SessionId: session.PublicId,
|
||||
ClientAddress: "127.0.0.1",
|
||||
ClientPort: 22,
|
||||
BackendAddress: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
connection, st, err := repo.CreateConnection(context.Background(), tt.args.connection)
|
||||
if tt.wantErr {
|
||||
assert.Error(err)
|
||||
assert.Nil(connection)
|
||||
assert.Nil(st)
|
||||
if tt.wantIsError != nil {
|
||||
assert.True(errors.Is(err, tt.wantIsError))
|
||||
}
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.NotNil(connection.CreateTime)
|
||||
assert.NotNil(st.StartTime)
|
||||
assert.Equal(st.Status, StatusConnected.String())
|
||||
found, foundStates, err := repo.LookupConnection(context.Background(), connection.PublicId)
|
||||
assert.NoError(err)
|
||||
assert.Equal(found, connection)
|
||||
|
||||
err = db.TestVerifyOplog(t, rw, connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
|
||||
assert.Error(err)
|
||||
|
||||
require.Equal(1, len(foundStates))
|
||||
assert.Equal(foundStates[0].Status, StatusConnected.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_UpdateConnectionState(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connection *Connection
|
||||
newStatus ConnectionStatus
|
||||
overrideConnectionId *string
|
||||
overrideConnectionVersion *uint32
|
||||
wantStateCnt int
|
||||
wantErr bool
|
||||
wantIsError error
|
||||
}{
|
||||
{
|
||||
name: "closed",
|
||||
connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222),
|
||||
newStatus: StatusClosed,
|
||||
wantStateCnt: 2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "bad-version",
|
||||
connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222),
|
||||
newStatus: StatusClosed,
|
||||
overrideConnectionVersion: func() *uint32 {
|
||||
v := uint32(22)
|
||||
return &v
|
||||
}(),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty-version",
|
||||
connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222),
|
||||
newStatus: StatusClosed,
|
||||
overrideConnectionVersion: func() *uint32 {
|
||||
v := uint32(0)
|
||||
return &v
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
{
|
||||
name: "bad-connectionId",
|
||||
connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222),
|
||||
newStatus: StatusClosed,
|
||||
overrideConnectionId: func() *string {
|
||||
s := "sc_thisIsNotValid"
|
||||
return &s
|
||||
}(),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty-connectionId",
|
||||
connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222),
|
||||
newStatus: StatusClosed,
|
||||
overrideConnectionId: func() *string {
|
||||
s := ""
|
||||
return &s
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantIsError: db.ErrInvalidParameter,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
var id string
|
||||
var version uint32
|
||||
switch {
|
||||
case tt.overrideConnectionId != nil:
|
||||
id = *tt.overrideConnectionId
|
||||
default:
|
||||
id = tt.connection.PublicId
|
||||
}
|
||||
switch {
|
||||
case tt.overrideConnectionVersion != nil:
|
||||
version = *tt.overrideConnectionVersion
|
||||
default:
|
||||
version = tt.connection.Version
|
||||
}
|
||||
|
||||
s, ss, err := repo.UpdateConnectionState(context.Background(), id, version, tt.newStatus)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
if tt.wantIsError != nil {
|
||||
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NotNil(s)
|
||||
require.NotNil(ss)
|
||||
assert.Equal(tt.wantStateCnt, len(ss))
|
||||
assert.Equal(tt.newStatus.String(), ss[0].Status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_UpdateConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
type args struct {
|
||||
closedReason ClosedReason
|
||||
bytesUp uint64
|
||||
bytesDown uint64
|
||||
fieldMaskPaths []string
|
||||
opt []Option
|
||||
publicId *string // not updateable - db.ErrInvalidFieldMask
|
||||
sessionId string // not updateable - db.ErrInvalidFieldMask
|
||||
clientAddress string // not updateable - db.ErrInvalidFieldMask
|
||||
clientPort uint32 // not updateable - db.ErrInvalidFieldMask
|
||||
backendAddress string // not updateable - db.ErrInvalidFieldMask
|
||||
backendPort uint32 // not updateable - db.ErrInvalidFieldMask
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantRowsUpdate int
|
||||
wantErr bool
|
||||
wantIsError error
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
closedReason: ConnectionClosedByUser,
|
||||
bytesUp: uint64(111),
|
||||
bytesDown: uint64(1),
|
||||
fieldMaskPaths: []string{"ClosedReason", "BytesUp", "BytesDown"},
|
||||
},
|
||||
wantErr: false,
|
||||
wantRowsUpdate: 1,
|
||||
},
|
||||
{
|
||||
name: "publicId",
|
||||
args: args{
|
||||
publicId: func() *string {
|
||||
id, err := newConnectionId()
|
||||
require.NoError(t, err)
|
||||
return &id
|
||||
}(),
|
||||
fieldMaskPaths: []string{"PublicId"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
{
|
||||
name: "sessionId",
|
||||
args: args{
|
||||
sessionId: func() string {
|
||||
id, err := newId()
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}(),
|
||||
fieldMaskPaths: []string{"SessionId"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
{
|
||||
name: "clientAddress",
|
||||
args: args{
|
||||
clientAddress: "127.0.0.1",
|
||||
fieldMaskPaths: []string{"ClientAddress"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
{
|
||||
name: "clientPort",
|
||||
args: args{
|
||||
clientPort: 443,
|
||||
fieldMaskPaths: []string{"ClientPort"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
{
|
||||
name: "backendAddress",
|
||||
args: args{
|
||||
backendAddress: "127.0.0.1",
|
||||
fieldMaskPaths: []string{"BackendAddress"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
{
|
||||
name: "backendPort",
|
||||
args: args{
|
||||
backendPort: 4443,
|
||||
fieldMaskPaths: []string{"BackendPort"},
|
||||
},
|
||||
wantErr: true,
|
||||
wantRowsUpdate: 0,
|
||||
wantIsError: db.ErrInvalidFieldMask,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
c := TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "127.0.0.1", 2222)
|
||||
|
||||
updateConnection := AllocConnection()
|
||||
updateConnection.PublicId = c.PublicId
|
||||
if tt.args.publicId != nil {
|
||||
updateConnection.PublicId = *tt.args.publicId
|
||||
}
|
||||
updateConnection.BytesUp = tt.args.bytesUp
|
||||
updateConnection.BytesDown = tt.args.bytesDown
|
||||
updateConnection.ClosedReason = tt.args.closedReason.String()
|
||||
updateConnection.Version = c.Version
|
||||
afterUpdate, afterUpdateState, updatedRows, err := repo.UpdateConnection(context.Background(), &updateConnection, updateConnection.Version, tt.args.fieldMaskPaths, tt.args.opt...)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
if tt.wantIsError != nil {
|
||||
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error: %s", err.Error())
|
||||
}
|
||||
assert.Nil(afterUpdate)
|
||||
assert.Nil(afterUpdateState)
|
||||
assert.Equal(0, updatedRows)
|
||||
err = db.TestVerifyOplog(t, rw, c.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
|
||||
assert.Error(err)
|
||||
assert.True(errors.Is(db.ErrRecordNotFound, err))
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tt.wantRowsUpdate, updatedRows)
|
||||
require.NotNil(afterUpdate)
|
||||
require.NotNil(afterUpdateState)
|
||||
switch tt.name {
|
||||
case "valid-no-op":
|
||||
assert.Equal(c.UpdateTime, afterUpdate.UpdateTime)
|
||||
default:
|
||||
assert.NotEqual(c.UpdateTime, afterUpdate.UpdateTime)
|
||||
}
|
||||
found, foundStates, err := repo.LookupConnection(context.Background(), c.PublicId)
|
||||
require.NoError(err)
|
||||
assert.Equal(afterUpdate, found)
|
||||
dbassrt := dbassert.New(t, rw)
|
||||
if tt.args.bytesUp == 0 {
|
||||
dbassrt.IsNull(found, "BytesUp")
|
||||
}
|
||||
if tt.args.bytesDown == 0 {
|
||||
dbassrt.IsNull(found, "BytesDown")
|
||||
}
|
||||
if tt.args.closedReason == "" {
|
||||
dbassrt.IsNull(found, "ClosedReason")
|
||||
}
|
||||
assert.Equal(tt.args.closedReason.String(), found.ClosedReason)
|
||||
assert.Equal(tt.args.bytesUp, found.BytesUp)
|
||||
assert.Equal(tt.args.bytesDown, found.BytesDown)
|
||||
|
||||
err = db.TestVerifyOplog(t, rw, c.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
|
||||
assert.Error(err)
|
||||
|
||||
require.Equal(1, len(foundStates))
|
||||
assert.Equal(StatusConnected.String(), foundStates[0].Status)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRepository_DeleteConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
session := TestDefaultSession(t, conn, wrapper, iamRepo)
|
||||
|
||||
type args struct {
|
||||
connection *Connection
|
||||
opt []Option
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantRowsDeleted int
|
||||
wantErr bool
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
connection: TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222),
|
||||
},
|
||||
wantRowsDeleted: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no-public-id",
|
||||
args: args{
|
||||
connection: func() *Connection {
|
||||
c := AllocConnection()
|
||||
return &c
|
||||
}(),
|
||||
},
|
||||
wantRowsDeleted: 0,
|
||||
wantErr: true,
|
||||
wantErrMsg: "delete connection: missing public id invalid parameter",
|
||||
},
|
||||
{
|
||||
name: "not-found",
|
||||
args: args{
|
||||
connection: func() *Connection {
|
||||
c := AllocConnection()
|
||||
id, err := newConnectionId()
|
||||
require.NoError(t, err)
|
||||
c.PublicId = id
|
||||
return &c
|
||||
}(),
|
||||
},
|
||||
wantRowsDeleted: 0,
|
||||
wantErr: true,
|
||||
wantErrMsg: "delete connection: failed record not found for ",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
deletedRows, err := repo.DeleteConnection(context.Background(), tt.args.connection.PublicId, tt.args.opt...)
|
||||
if tt.wantErr {
|
||||
assert.Error(err)
|
||||
assert.Equal(0, deletedRows)
|
||||
assert.Contains(err.Error(), tt.wantErrMsg)
|
||||
err = db.TestVerifyOplog(t, rw, tt.args.connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second))
|
||||
assert.Error(err)
|
||||
assert.True(errors.Is(db.ErrRecordNotFound, err))
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
||||
found, _, err := repo.LookupConnection(context.Background(), tt.args.connection.PublicId)
|
||||
assert.NoError(err)
|
||||
assert.Nil(found)
|
||||
|
||||
err = db.TestVerifyOplog(t, rw, tt.args.connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second))
|
||||
assert.Error(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue