diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go new file mode 100644 index 0000000000..b76a82fb5b --- /dev/null +++ b/internal/session/repository_connection.go @@ -0,0 +1,297 @@ +package session + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/hashicorp/boundary/internal/db" + dbcommon "github.com/hashicorp/boundary/internal/db/common" +) + +// CreateConnection inserts into the repository and returns the new Connection with +// its State of "Pending". The following fields must be empty when creating a +// session: PublicId, BytesUp, BytesDown, ClosedReason, Version, CreateTime, +// UpdateTime. No options are currently supported. +func (r *Repository) CreateConnection(ctx context.Context, newConnection *Connection, opt ...Option) (*Connection, *ConnectionState, error) { + if newConnection == nil { + return nil, nil, fmt.Errorf("create connection: missing connection: %w", db.ErrInvalidParameter) + } + if newConnection.PublicId != "" { + return nil, nil, fmt.Errorf("create connection: public id is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.BytesUp != 0 { + return nil, nil, fmt.Errorf("create connection: bytes down is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.BytesDown != 0 { + return nil, nil, fmt.Errorf("create connection: bytes up is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.ClosedReason != "" { + return nil, nil, fmt.Errorf("create connection: closed reason is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.Version != 0 { + return nil, nil, fmt.Errorf("create connection: version is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.CreateTime != nil { + return nil, nil, fmt.Errorf("create connection: create time is not empty: %w", db.ErrInvalidParameter) + } + if newConnection.UpdateTime != nil { + return nil, nil, fmt.Errorf("create connection: update time is not empty: %w", db.ErrInvalidParameter) + } + + id, err := newConnectionId() + if err != nil { + return nil, nil, fmt.Errorf("create connection: %w", err) + } + newConnection.PublicId = id + + var returnedConnection *Connection + var returnedState *ConnectionState + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(read db.Reader, w db.Writer) error { + returnedConnection = newConnection.Clone().(*Connection) + if err = w.Create(ctx, returnedConnection); err != nil { + return err + } + var foundStates []*ConnectionState + // trigger will create new "Pending" state + if foundStates, err = fetchConnectionStates(ctx, read, returnedConnection.PublicId); err != nil { + return err + } + if len(foundStates) != 1 { + return fmt.Errorf("%d states found for new connection %s", len(foundStates), returnedConnection.PublicId) + } + returnedState = foundStates[0] + if returnedState.Status != StatusConnected.String() { + return fmt.Errorf("new connection %s state is not valid: %s", returnedConnection.PublicId, returnedState.Status) + } + return nil + }, + ) + if err != nil { + return nil, nil, fmt.Errorf("create connection: %w", err) + } + return returnedConnection, returnedState, err +} + +// LookupConnection will look up a connection in the repository and return the connection +// with its states. If the connection is not found, it will return nil, nil, nil. +// No options are currently supported. +func (r *Repository) LookupConnection(ctx context.Context, connectionId string, opt ...Option) (*Connection, []*ConnectionState, error) { + if connectionId == "" { + return nil, nil, fmt.Errorf("lookup connection: missing connectionId id: %w", db.ErrInvalidParameter) + } + connection := AllocConnection() + connection.PublicId = connectionId + var states []*ConnectionState + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(read db.Reader, w db.Writer) error { + if err := read.LookupById(ctx, &connection); err != nil { + return fmt.Errorf("lookup connection: failed %w for %s", err, connectionId) + } + var err error + if states, err = fetchConnectionStates(ctx, read, connectionId, db.WithOrder("start_time desc")); err != nil { + return err + } + return nil + }, + ) + if err != nil { + if errors.Is(err, db.ErrRecordNotFound) { + return nil, nil, nil + } + return nil, nil, fmt.Errorf("lookup connection: %w", err) + } + return &connection, states, nil +} + +// ListConnections will sessions. Supports the WithLimit and WithOrder options. +func (r *Repository) ListConnections(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) { + var connections []*Connection + err := r.list(ctx, &connections, "session_id = ?", []interface{}{sessionId}, opt...) // pass options, so WithLimit and WithOrder are supported + if err != nil { + return nil, fmt.Errorf("list connections: %w", err) + } + return connections, nil +} + +// DeleteConnection will delete a connection from the repository. +func (r *Repository) DeleteConnection(ctx context.Context, publicId string, opt ...Option) (int, error) { + if publicId == "" { + return db.NoRowsAffected, fmt.Errorf("delete connection: missing public id %w", db.ErrInvalidParameter) + } + connection := AllocConnection() + connection.PublicId = publicId + if err := r.reader.LookupByPublicId(ctx, &connection); err != nil { + return db.NoRowsAffected, fmt.Errorf("delete connection: failed %w for %s", err, publicId) + } + + var rowsDeleted int + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(_ db.Reader, w db.Writer) error { + deleteSession := connection.Clone() + var err error + rowsDeleted, err = w.Delete( + ctx, + deleteSession, + ) + if err == nil && rowsDeleted > 1 { + // return err, which will result in a rollback of the delete + return errors.New("error more than 1 connection would have been deleted") + } + return err + }, + ) + if err != nil { + return db.NoRowsAffected, fmt.Errorf("delete connection: failed %w for %s", err, publicId) + } + return rowsDeleted, nil +} + +// UpdateConnection updates the repository entry for the connection, using the +// fieldMaskPaths. Only BytesUp, BytesDown, and ClosedReason are muttable and +// will be set to NULL if set to a zero value and included in the fieldMaskPaths. +func (r *Repository) UpdateConnection(ctx context.Context, connection *Connection, version uint32, fieldMaskPaths []string, opt ...Option) (*Connection, []*ConnectionState, int, error) { + if connection == nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: missing connection %w", db.ErrInvalidParameter) + } + if connection.PublicId == "" { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: missing connection public id %w", db.ErrInvalidParameter) + } + for _, f := range fieldMaskPaths { + switch { + case strings.EqualFold("BytesUp", f): + case strings.EqualFold("BytesDown", f): + case strings.EqualFold("ClosedReason", f): + default: + return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: field: %s: %w", f, db.ErrInvalidFieldMask) + } + } + var dbMask, nullFields []string + dbMask, nullFields = dbcommon.BuildUpdatePaths( + map[string]interface{}{ + "BytesUp": connection.BytesUp, + "BytesDown": connection.BytesDown, + "ClosedReason": connection.ClosedReason, + }, + fieldMaskPaths, + ) + if len(dbMask) == 0 && len(nullFields) == 0 { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: %w", db.ErrEmptyFieldMask) + } + + var c *Connection + var states []*ConnectionState + var rowsUpdated int + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + var err error + c = connection.Clone().(*Connection) + rowsUpdated, err = w.Update( + ctx, + c, + dbMask, + nullFields, + ) + if err != nil { + return err + } + if err == nil && rowsUpdated > 1 { + // return err, which will result in a rollback of the update + return errors.New("error more than 1 connection would have been updated ") + } + states, err = fetchConnectionStates(ctx, reader, c.PublicId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update connection: %w for %s", err, connection.PublicId) + } + return c, states, rowsUpdated, err +} + +// UpdateConnectionState will update the connection's state using the connection id and its +// version. No options are currently supported. +func (r *Repository) UpdateConnectionState(ctx context.Context, connectionId string, connectionVersion uint32, s ConnectionStatus, opt ...Option) (*Connection, []*ConnectionState, error) { + if connectionId == "" { + return nil, nil, fmt.Errorf("update connection state: missing session id %w", db.ErrInvalidParameter) + } + if connectionVersion == 0 { + return nil, nil, fmt.Errorf("update connection state: version cannot be zero: %w", db.ErrInvalidParameter) + } + if s == "" { + return nil, nil, fmt.Errorf("update connection state: missing connection status: %w", db.ErrInvalidParameter) + } + + newState, err := NewConnectionState(connectionId, s) + if err != nil { + return nil, nil, fmt.Errorf("update connection state: %w", err) + } + sessionConnection, _, err := r.LookupConnection(ctx, connectionId) + if err != nil { + return nil, nil, fmt.Errorf("update connection state: %w", err) + } + if sessionConnection == nil { + return nil, nil, fmt.Errorf("update connection state: unable to look up connection for %s: %w", connectionId, err) + } + + updatedConnection := AllocConnection() + var returnedStates []*ConnectionState + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + // We need to update the session version as that's the aggregate + updatedConnection.PublicId = connectionId + updatedConnection.Version = uint32(connectionVersion) + 1 + rowsUpdated, err := w.Update(ctx, &updatedConnection, []string{"Version"}, nil, db.WithVersion(&connectionVersion)) + if err != nil { + return fmt.Errorf("unable to update connection version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("updated connection and %d rows updated", rowsUpdated) + } + if err := w.Create(ctx, newState); err != nil { + return fmt.Errorf("unable to add new state: %w", err) + } + + returnedStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return nil, nil, fmt.Errorf("update connection state: error creating new state: %w", err) + } + return &updatedConnection, returnedStates, nil +} + +func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) { + var states []*ConnectionState + if err := r.SearchWhere(ctx, &states, "connection_id = ?", []interface{}{connectionId}, opt...); err != nil { + return nil, fmt.Errorf("fetch connection states: %w", err) + } + if len(states) == 0 { + return nil, nil + } + return states, nil +} diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go new file mode 100644 index 0000000000..0f983e7aef --- /dev/null +++ b/internal/session/repository_connection_test.go @@ -0,0 +1,631 @@ +package session + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/protobuf/ptypes" + "github.com/hashicorp/boundary/internal/db" + dbassert "github.com/hashicorp/boundary/internal/db/assert" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/oplog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRepository_ListConnection(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + const testLimit = 10 + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + rw := db.New(conn) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit)) + require.NoError(t, err) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + + type args struct { + searchForSessionId string + opt []Option + } + tests := []struct { + name string + createCnt int + args args + wantCnt int + wantErr bool + }{ + { + name: "no-limit", + createCnt: repo.defaultLimit + 1, + args: args{ + searchForSessionId: session.PublicId, + opt: []Option{WithLimit(-1)}, + }, + wantCnt: repo.defaultLimit + 1, + wantErr: false, + }, + { + name: "default-limit", + createCnt: repo.defaultLimit + 1, + args: args{ + searchForSessionId: session.PublicId, + }, + wantCnt: repo.defaultLimit, + wantErr: false, + }, + { + name: "custom-limit", + createCnt: repo.defaultLimit + 1, + args: args{ + searchForSessionId: session.PublicId, + opt: []Option{WithLimit(3)}, + }, + wantCnt: 3, + wantErr: false, + }, + { + name: "bad-session-id", + createCnt: repo.defaultLimit + 1, + args: args{ + searchForSessionId: "s_thisIsNotValid", + }, + wantCnt: 0, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + require.NoError(conn.Where("1=1").Delete(AllocConnection()).Error) + testConnections := []*Connection{} + for i := 0; i < tt.createCnt; i++ { + c := TestConnection(t, conn, + session.PublicId, + "127.0.0.1", + 22, + "127.0.0.1", + 2222, + ) + testConnections = append(testConnections, c) + } + assert.Equal(tt.createCnt, len(testConnections)) + got, err := repo.ListConnections(context.Background(), tt.args.searchForSessionId, tt.args.opt...) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assert.Equal(tt.wantCnt, len(got)) + }) + } + t.Run("withOrder", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + require.NoError(conn.Where("1=1").Delete(AllocConnection()).Error) + wantCnt := 5 + for i := 0; i < wantCnt; i++ { + _ = TestConnection(t, conn, + session.PublicId, + "127.0.0.1", + 22, + "127.0.0.1", + 2222, + ) + } + got, err := repo.ListConnections(context.Background(), session.PublicId, WithOrder("create_time asc")) + require.NoError(err) + assert.Equal(wantCnt, len(got)) + + for i := 0; i < len(got)-1; i++ { + first, err := ptypes.Timestamp(got[i].CreateTime.Timestamp) + require.NoError(err) + second, err := ptypes.Timestamp(got[i+1].CreateTime.Timestamp) + require.NoError(err) + assert.True(first.Before(second)) + } + }) +} + +func TestRepository_CreateConnection(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + + type args struct { + connection *Connection + } + tests := []struct { + name string + args args + wantErr bool + wantIsError error + }{ + { + name: "valid", + args: args{ + connection: func() *Connection { + c, err := NewConnection( + session.PublicId, + "127.0.0.1", + 22, + "127.0.0.1", + 2222, + ) + require.NoError(t, err) + return c + }(), + }, + wantErr: false, + }, + { + name: "empty-session-id", + args: args{ + connection: &Connection{ + ClientAddress: "127.0.0.1", + ClientPort: 22, + BackendAddress: "127.0.0.1", + BackendPort: 2222, + }, + }, + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "empty-client-address", + args: args{ + connection: &Connection{ + SessionId: session.PublicId, + ClientPort: 22, + BackendAddress: "127.0.0.1", + BackendPort: 2222, + }, + }, + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "empty-client-port", + args: args{ + connection: &Connection{ + SessionId: session.PublicId, + ClientAddress: "127.0.0.1", + BackendAddress: "127.0.0.1", + BackendPort: 2222, + }, + }, + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "empty-backend-address", + args: args{ + connection: &Connection{ + SessionId: session.PublicId, + ClientAddress: "127.0.0.1", + ClientPort: 22, + BackendPort: 2222, + }, + }, + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "empty-backend-port", + args: args{ + connection: &Connection{ + SessionId: session.PublicId, + ClientAddress: "127.0.0.1", + ClientPort: 22, + BackendAddress: "127.0.0.1", + }, + }, + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + connection, st, err := repo.CreateConnection(context.Background(), tt.args.connection) + if tt.wantErr { + assert.Error(err) + assert.Nil(connection) + assert.Nil(st) + if tt.wantIsError != nil { + assert.True(errors.Is(err, tt.wantIsError)) + } + return + } + require.NoError(err) + assert.NotNil(connection.CreateTime) + assert.NotNil(st.StartTime) + assert.Equal(st.Status, StatusConnected.String()) + found, foundStates, err := repo.LookupConnection(context.Background(), connection.PublicId) + assert.NoError(err) + assert.Equal(found, connection) + + err = db.TestVerifyOplog(t, rw, connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + + require.Equal(1, len(foundStates)) + assert.Equal(foundStates[0].Status, StatusConnected.String()) + }) + } +} + +func TestRepository_UpdateConnectionState(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + + tests := []struct { + name string + connection *Connection + newStatus ConnectionStatus + overrideConnectionId *string + overrideConnectionVersion *uint32 + wantStateCnt int + wantErr bool + wantIsError error + }{ + { + name: "closed", + connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222), + newStatus: StatusClosed, + wantStateCnt: 2, + wantErr: false, + }, + { + name: "bad-version", + connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222), + newStatus: StatusClosed, + overrideConnectionVersion: func() *uint32 { + v := uint32(22) + return &v + }(), + wantErr: true, + }, + { + name: "empty-version", + connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222), + newStatus: StatusClosed, + overrideConnectionVersion: func() *uint32 { + v := uint32(0) + return &v + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "bad-connectionId", + connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222), + newStatus: StatusClosed, + overrideConnectionId: func() *string { + s := "sc_thisIsNotValid" + return &s + }(), + wantErr: true, + }, + { + name: "empty-connectionId", + connection: TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222), + newStatus: StatusClosed, + overrideConnectionId: func() *string { + s := "" + return &s + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var id string + var version uint32 + switch { + case tt.overrideConnectionId != nil: + id = *tt.overrideConnectionId + default: + id = tt.connection.PublicId + } + switch { + case tt.overrideConnectionVersion != nil: + version = *tt.overrideConnectionVersion + default: + version = tt.connection.Version + } + + s, ss, err := repo.UpdateConnectionState(context.Background(), id, version, tt.newStatus) + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) + } + return + } + require.NoError(err) + require.NotNil(s) + require.NotNil(ss) + assert.Equal(tt.wantStateCnt, len(ss)) + assert.Equal(tt.newStatus.String(), ss[0].Status) + }) + } +} + +func TestRepository_UpdateConnection(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + + type args struct { + closedReason ClosedReason + bytesUp uint64 + bytesDown uint64 + fieldMaskPaths []string + opt []Option + publicId *string // not updateable - db.ErrInvalidFieldMask + sessionId string // not updateable - db.ErrInvalidFieldMask + clientAddress string // not updateable - db.ErrInvalidFieldMask + clientPort uint32 // not updateable - db.ErrInvalidFieldMask + backendAddress string // not updateable - db.ErrInvalidFieldMask + backendPort uint32 // not updateable - db.ErrInvalidFieldMask + } + tests := []struct { + name string + args args + wantRowsUpdate int + wantErr bool + wantIsError error + }{ + { + name: "valid", + args: args{ + closedReason: ConnectionClosedByUser, + bytesUp: uint64(111), + bytesDown: uint64(1), + fieldMaskPaths: []string{"ClosedReason", "BytesUp", "BytesDown"}, + }, + wantErr: false, + wantRowsUpdate: 1, + }, + { + name: "publicId", + args: args{ + publicId: func() *string { + id, err := newConnectionId() + require.NoError(t, err) + return &id + }(), + fieldMaskPaths: []string{"PublicId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "sessionId", + args: args{ + sessionId: func() string { + id, err := newId() + require.NoError(t, err) + return id + }(), + fieldMaskPaths: []string{"SessionId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "clientAddress", + args: args{ + clientAddress: "127.0.0.1", + fieldMaskPaths: []string{"ClientAddress"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "clientPort", + args: args{ + clientPort: 443, + fieldMaskPaths: []string{"ClientPort"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "backendAddress", + args: args{ + backendAddress: "127.0.0.1", + fieldMaskPaths: []string{"BackendAddress"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "backendPort", + args: args{ + backendPort: 4443, + fieldMaskPaths: []string{"BackendPort"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + c := TestConnection(t, conn, session.PublicId, "0.0.0.0", 22, "127.0.0.1", 2222) + + updateConnection := AllocConnection() + updateConnection.PublicId = c.PublicId + if tt.args.publicId != nil { + updateConnection.PublicId = *tt.args.publicId + } + updateConnection.BytesUp = tt.args.bytesUp + updateConnection.BytesDown = tt.args.bytesDown + updateConnection.ClosedReason = tt.args.closedReason.String() + updateConnection.Version = c.Version + afterUpdate, afterUpdateState, updatedRows, err := repo.UpdateConnection(context.Background(), &updateConnection, updateConnection.Version, tt.args.fieldMaskPaths, tt.args.opt...) + + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error: %s", err.Error()) + } + assert.Nil(afterUpdate) + assert.Nil(afterUpdateState) + assert.Equal(0, updatedRows) + err = db.TestVerifyOplog(t, rw, c.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + return + } + require.NoError(err) + assert.Equal(tt.wantRowsUpdate, updatedRows) + require.NotNil(afterUpdate) + require.NotNil(afterUpdateState) + switch tt.name { + case "valid-no-op": + assert.Equal(c.UpdateTime, afterUpdate.UpdateTime) + default: + assert.NotEqual(c.UpdateTime, afterUpdate.UpdateTime) + } + found, foundStates, err := repo.LookupConnection(context.Background(), c.PublicId) + require.NoError(err) + assert.Equal(afterUpdate, found) + dbassrt := dbassert.New(t, rw) + if tt.args.bytesUp == 0 { + dbassrt.IsNull(found, "BytesUp") + } + if tt.args.bytesDown == 0 { + dbassrt.IsNull(found, "BytesDown") + } + if tt.args.closedReason == "" { + dbassrt.IsNull(found, "ClosedReason") + } + assert.Equal(tt.args.closedReason.String(), found.ClosedReason) + assert.Equal(tt.args.bytesUp, found.BytesUp) + assert.Equal(tt.args.bytesDown, found.BytesDown) + + err = db.TestVerifyOplog(t, rw, c.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + + require.Equal(1, len(foundStates)) + assert.Equal(StatusConnected.String(), foundStates[0].Status) + }) + } + +} + +func TestRepository_DeleteConnection(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + + type args struct { + connection *Connection + opt []Option + } + tests := []struct { + name string + args args + wantRowsDeleted int + wantErr bool + wantErrMsg string + }{ + { + name: "valid", + args: args{ + connection: TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222), + }, + wantRowsDeleted: 1, + wantErr: false, + }, + { + name: "no-public-id", + args: args{ + connection: func() *Connection { + c := AllocConnection() + return &c + }(), + }, + wantRowsDeleted: 0, + wantErr: true, + wantErrMsg: "delete connection: missing public id invalid parameter", + }, + { + name: "not-found", + args: args{ + connection: func() *Connection { + c := AllocConnection() + id, err := newConnectionId() + require.NoError(t, err) + c.PublicId = id + return &c + }(), + }, + wantRowsDeleted: 0, + wantErr: true, + wantErrMsg: "delete connection: failed record not found for ", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + deletedRows, err := repo.DeleteConnection(context.Background(), tt.args.connection.PublicId, tt.args.opt...) + if tt.wantErr { + assert.Error(err) + assert.Equal(0, deletedRows) + assert.Contains(err.Error(), tt.wantErrMsg) + err = db.TestVerifyOplog(t, rw, tt.args.connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + return + } + assert.NoError(err) + assert.Equal(tt.wantRowsDeleted, deletedRows) + found, _, err := repo.LookupConnection(context.Background(), tt.args.connection.PublicId) + assert.NoError(err) + assert.Nil(found) + + err = db.TestVerifyOplog(t, rw, tt.args.connection.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_DELETE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + }) + } +}