refact(session connection): remove session connection state table (#4617)

* refact(session connection): remove session connection state table
pull/4954/head^2
Irena Rindos 2 years ago committed by GitHub
parent eb6e1b558d
commit 2191aa1d03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -622,16 +622,13 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs
return nil, status.Errorf(codes.NotFound, "worker not found with name %q", req.GetWorkerId())
}
connectionInfo, connStates, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId())
connectionInfo, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId())
if err != nil {
return nil, err
}
if connectionInfo == nil {
return nil, status.Error(codes.Internal, "Invalid authorize connection response.")
}
if len(connStates) == 0 {
return nil, status.Error(codes.Internal, "Invalid connection state in authorize response.")
}
sessInfo, authzSummary, err := sessionRepo.LookupSession(ctx, req.GetSessionId())
if err != nil {
@ -648,7 +645,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs
ret := &pbs.AuthorizeConnectionResponse{
ConnectionId: connectionInfo.GetPublicId(),
Status: connStates[0].Status.ProtoVal(),
Status: session.ConnectionStatusFromString(connectionInfo.Status).ProtoVal(),
ConnectionsLeft: authzSummary.ConnectionLimit,
Route: route,
}
@ -680,7 +677,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
}
connectionInfo, connStates, err := connRepo.ConnectConnection(ctx, session.ConnectWith{
connectionInfo, err := connRepo.ConnectConnection(ctx, session.ConnectWith{
ConnectionId: req.GetConnectionId(),
ClientTcpAddress: req.GetClientTcpAddress(),
ClientTcpPort: req.GetClientTcpPort(),
@ -696,7 +693,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C
}
return &pbs.ConnectConnectionResponse{
Status: connStates[0].Status.ProtoVal(),
Status: session.ConnectionStatusFromString(connectionInfo.Status).ProtoVal(),
}, nil
}
@ -742,12 +739,9 @@ func (ws *workerServiceServer) CloseConnection(ctx context.Context, req *pbs.Clo
if v.Connection == nil {
return nil, status.Errorf(codes.Internal, "No connection found while closing one of the connection IDs: %v", closeIds)
}
if len(v.ConnectionStates) == 0 {
return nil, status.Errorf(codes.Internal, "No connection states found while closing one of the connection IDs: %v", closeIds)
}
closeData = append(closeData, &pbs.CloseConnectionResponseData{
ConnectionId: v.Connection.GetPublicId(),
Status: v.ConnectionStates[0].Status.ProtoVal(),
Status: v.ConnectionState.ProtoVal(),
})
}

@ -97,7 +97,7 @@ func TestStatus(t *testing.T) {
tofu := session.TestTofu(t)
canceledSess, _, err = repo.ActivateSession(ctx, canceledSess.PublicId, canceledSess.Version, tofu)
require.NoError(t, err)
canceledConn, _, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId)
canceledConn, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId)
require.NoError(t, err)
canceledSess, err = repo.CancelSession(ctx, canceledSess.PublicId, canceledSess.Version)
@ -120,7 +120,7 @@ func TestStatus(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)
cases := []struct {
@ -562,7 +562,7 @@ func TestStatusSessionClosed(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)
cases := []struct {
@ -757,9 +757,9 @@ func TestStatusDeadConnection(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)
deadConn, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId)
deadConn, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId)
require.NoError(t, err)
require.NotEqual(t, deadConn.PublicId, connection.PublicId)
@ -823,12 +823,10 @@ func TestStatusDeadConnection(t *testing.T) {
),
)
gotConn, states, err := connRepo.LookupConnection(ctx, deadConn.PublicId)
gotConn, err := connRepo.LookupConnection(ctx, deadConn.PublicId)
require.NoError(t, err)
assert.Equal(t, session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason))
assert.Equal(t, 2, len(states))
assert.Nil(t, states[0].EndTime)
assert.Equal(t, session.StatusClosed, states[0].Status)
assert.Equal(t, session.StatusClosed, session.ConnectionStatusFromString(gotConn.Status))
}
func TestStatusWorkerWithKeyId(t *testing.T) {
@ -927,7 +925,7 @@ func TestStatusWorkerWithKeyId(t *testing.T) {
s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce)
require.NotNil(t, s)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId)
require.NoError(t, err)
cases := []struct {

@ -219,6 +219,7 @@ begin;
create trigger insert_new_session_state after insert on session
for each row execute procedure insert_new_session_state();
-- Updated in 90/01_remove_session_connection_state
-- update_connection_state_on_closed_reason() is used in an update insert trigger on the
-- session_connection table. it will valiadate that all the session's
-- connections are closed, and then insert a state of "closed" in

@ -148,6 +148,7 @@ begin;
create trigger default_create_time_column before insert on session_connection
for each row execute procedure default_create_time();
-- Removed in 90/01_remove_session_connection_state.up.sql
-- insert_new_connection_state() is used in an after insert trigger on the
-- session_connection table. it will insert a state of "authorized" in
-- session_connection_state for the new session connection.

@ -404,6 +404,7 @@ begin;
drop trigger wh_insert_session_connection_state on session_connection_state;
drop function wh_insert_session_connection_state;
-- Updated in 90/01_remove_session_connection_state.up.sql
create function wh_insert_session_connection_state() returns trigger
as $$
declare

@ -7,6 +7,7 @@ begin;
drop trigger update_connection_state_on_closed_reason on session_connection;
drop function update_connection_state_on_closed_reason();
-- Removed in 90/01_remove_session_connection_state.up.sql
create function update_connection_state_on_closed_reason() returns trigger
as $$
begin

@ -7,6 +7,7 @@ begin;
drop trigger wh_insert_session_connection on session_connection;
drop function wh_insert_session_connection();
-- Updated in 90/01_remove_session_connection_state
create function wh_insert_session_connection() returns trigger
as $$
declare

@ -0,0 +1,247 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1
begin;
-- Remove the session_connection_state table and any related triggers
drop trigger update_connection_state_on_closed_reason on session_connection;
drop function update_connection_state_on_closed_reason();
drop trigger insert_session_connection_state on session_connection_state;
drop function insert_session_connection_state();
drop trigger update_session_state_on_termination_reason on session;
drop function update_session_state_on_termination_reason();
drop trigger insert_new_connection_state on session_connection;
drop function insert_new_connection_state();
drop trigger immutable_columns on session_connection_state;
drop trigger wh_insert_session_connection_state on session_connection_state;
drop function wh_insert_session_connection_state();
drop trigger wh_insert_session_connection on session_connection;
drop function wh_insert_session_connection();
-- If the connected_time_range is null, it means the connection is authorized but not connected.
-- If the upper value of connected_time_range is > now() (upper range is infinity) then the state is connected.
-- If the upper value of connected_time_range is <= now() then the connection is closed.
alter table session_connection
add column connected_time_range tstzrange;
-- Migrate existing data from session_connection_state to session_connection
update session_connection
set connected_time_range = (select tstzrange(min(start_time), max(start_time))
from session_connection_state
where session_connection_state.connection_id = session_connection.public_id
group by connection_id );
drop table session_connection_state;
drop table session_connection_state_enm;
-- Insert on session_connection creates the connection entry, leaving the connected_time_range to null, indicating the connection is authorized
-- "Connected" is handled by the function ConnectConnection, which sets the connected_time_range lower bound to now() and upper bound to infinity
-- "Closed" is handled by the trigger function, update_connected_time_range_on_closed_reason, which sets the connected_time_range upper bound to now()
-- State transitions are guarded by the trigger function, check_connection_state_transition, which ensures that the state transitions are valid
create function check_connection_state_transition() returns trigger
as $$
begin
-- If old state was authorized, allow transition to connected or closed
if old.connected_time_range is null then
return new;
end if;
-- If old state was closed, no transitions are allowed
if upper(old.connected_time_range) < 'infinity' and old.connected_time_range != new.connected_time_range then
raise exception 'Invalid state transition from closed';
end if;
-- If old state was connected, allow transition to closed
if upper(old.connected_time_range) = 'infinity' and
upper(new.connected_time_range) != 'infinity' and
lower(old.connected_time_range) = lower(new.connected_time_range) then
return new;
else
raise exception 'Invalid state transition from connected';
end if;
return new;
end;
$$ language plpgsql;
create trigger check_connection_state_transition before update of connected_time_range on session_connection
for each row execute procedure check_connection_state_transition();
create function update_connected_time_range_on_closed_reason() returns trigger
as $$
begin
if new.closed_reason is not null then
if old.connected_time_range is null or upper(old.connected_time_range) = 'infinity'::timestamptz then
new.connected_time_range = tstzrange(lower(old.connected_time_range), now(), '[]');
end if;
end if;
return new;
end;
$$ language plpgsql;
create trigger update_connected_time_range_closed_reason before update of closed_reason on session_connection
for each row execute procedure update_connected_time_range_on_closed_reason();
create function update_session_state_on_termination_reason() returns trigger
as $$
begin
if new.termination_reason is not null then
perform
from session_connection
where session_id = new.public_id
and upper(connected_time_range) = 'infinity'::timestamptz;
if found then
raise 'session %s has open connections', new.public_id;
end if;
-- check to see if there's a terminated state already, before inserting a
-- new one.
perform
from session_state ss
where ss.session_id = new.public_id and
ss.state = 'terminated';
if found then
return new;
end if;
insert into session_state (session_id, state)
values (new.public_id, 'terminated');
end if;
return new;
end;
$$ language plpgsql;
create trigger update_session_state_on_termination_reason after update of termination_reason on session
for each row execute procedure update_session_state_on_termination_reason();
create function wh_insert_session_connection() returns trigger
as $$
declare
new_row wh_session_connection_accumulating_fact%rowtype;
begin
with
authorized_timestamp (date_dim_key, time_dim_key, ts) as (
select wh_date_key(create_time), wh_time_key(create_time), create_time
from session_connection
where public_id = new.public_id
and connected_time_range is null
),
session_dimension (host_dim_key, user_dim_key, credential_group_dim_key) as (
select host_key, user_key, credential_group_key
from wh_session_accumulating_fact
where session_id = new.session_id
)
insert into wh_session_connection_accumulating_fact (
connection_id,
session_id,
host_key,
user_key,
credential_group_key,
connection_authorized_date_key,
connection_authorized_time_key,
connection_authorized_time,
client_tcp_address,
client_tcp_port_number,
endpoint_tcp_address,
endpoint_tcp_port_number,
bytes_up,
bytes_down
)
select new.public_id,
new.session_id,
session_dimension.host_dim_key,
session_dimension.user_dim_key,
session_dimension.credential_group_dim_key,
authorized_timestamp.date_dim_key,
authorized_timestamp.time_dim_key,
authorized_timestamp.ts,
new.client_tcp_address,
new.client_tcp_port,
new.endpoint_tcp_address,
new.endpoint_tcp_port,
new.bytes_up,
new.bytes_down
from authorized_timestamp,
session_dimension
returning * into strict new_row;
return null;
end;
$$ language plpgsql;
create trigger wh_insert_session_connection after insert on session_connection
for each row execute function wh_insert_session_connection();
create function wh_insert_session_connection_state() returns trigger
as $$
declare
state text;
date_col text;
time_col text;
ts_col text;
q text;
connection_row wh_session_connection_accumulating_fact%rowtype;
begin
if new.connected_time_range is null then
-- Indicates authorized connection. The update statement in this
-- trigger will fail for the authorized state because the row for the
-- session connection has not yet been inserted into the
-- wh_session_connection_accumulating_fact table.
return null;
end if;
if upper(new.connected_time_range) = 'infinity'::timestamptz then
update wh_session_connection_accumulating_fact
set (connection_connected_date_key,
connection_connected_time_key,
connection_connected_time) = (select wh_date_key(new.update_time),
wh_time_key(new.update_time),
new.update_time::timestamptz)
where connection_id = new.public_id;
else
update wh_session_connection_accumulating_fact
set (connection_closed_date_key,
connection_closed_time_key,
connection_closed_time) = (select wh_date_key(new.update_time),
wh_time_key(new.update_time),
new.update_time::timestamptz)
where connection_id = new.public_id;
end if;
return null;
end;
$$ language plpgsql;
create trigger wh_insert_session_connection_state after update of connected_time_range on session_connection
for each row execute function wh_insert_session_connection_state();
create view session_connection_with_status_view as
select public_id,
session_id,
client_tcp_address,
client_tcp_port,
endpoint_tcp_address,
endpoint_tcp_port,
bytes_up,
bytes_down,
closed_reason,
version,
create_time,
update_time,
user_client_ip,
worker_id,
case
when connected_time_range is null then 'authorized'
when upper(connected_time_range) > now() then 'connected'
else 'closed'
end as status
from session_connection;
create index connected_time_range_idx on session_connection (connected_time_range);
create index connected_time_range_upper_idx on session_connection (upper(connected_time_range));
commit;

@ -0,0 +1,35 @@
-- Copyright (c) HashiCorp, Inc.
-- SPDX-License-Identifier: BUSL-1.1
begin;
select plan(6);
-- Ensure session connection table is populated
select is(count(*), 2::bigint) from session_connection;
-- Check that both session connections are in the authorized state (null connected_time_range)
select is(count(*), 2::bigint) from session_connection where connected_time_range is null;
-- Connect one of the session connections
update session_connection
set connected_time_range=tstzrange(now(),'infinity')
where public_id = 's1c1___clare';
select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) > now();
-- Close the other session connection
update session_connection
set closed_reason = 'unknown'
where public_id = 's2c1___clare';
select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) <= now();
-- Attempt to connect the closed session connection, expect an error
select throws_ok($$ update session_connection
set connected_time_range = tstzrange(now(), 'infinity')
where public_id = 's2c1___clare'$$);
-- Still only 1 connected session
select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) > now();
select * from finish();
rollback;

@ -12,7 +12,8 @@ begin;
update session_connection set
bytes_up = 10,
bytes_down = 5,
closed_reason = 'closed by end-user'
closed_reason = 'closed by end-user',
connected_time_range = tstzrange(now()::wh_timestamp, now()::wh_timestamp)
where public_id = 's1c1___clare';
select is(count(*), 2::bigint) from wh_session_connection_accumulating_fact;

@ -206,10 +206,10 @@ func TestLookupWorker(t *testing.T) {
sess := session.TestSession(t, conn, wrapper, composedOf, session.WithDbOpts(db.WithSkipVetForWrite(true)), session.WithExpirationTime(exp))
sess, _, err = sessRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo"))
require.NoError(t, err)
c, _, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId())
c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId())
require.NoError(t, err)
require.NotNil(t, c)
c, _, err = connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId())
c, err = connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId())
require.NoError(t, err)
require.NotNil(t, c)
}
@ -220,7 +220,7 @@ func TestLookupWorker(t *testing.T) {
session.WithDbOpts(db.WithSkipVetForWrite(true)))
sess2, _, err = sessRepo.ActivateSession(ctx, sess2.GetPublicId(), sess2.Version, []byte("foo"))
require.NoError(t, err)
c, _, err := connRepo.AuthorizeConnection(ctx, sess2.GetPublicId(), w.GetPublicId())
c, err := connRepo.AuthorizeConnection(ctx, sess2.GetPublicId(), w.GetPublicId())
require.NoError(t, err)
require.NotNil(t, c)
}

@ -13,7 +13,7 @@ import (
)
const (
defaultConnectionTableName = "session_connection"
defaultConnectionTableName = "session_connection_with_status_view" // "session_connection"
)
// Connection contains information about session's connection to a target
@ -44,6 +44,8 @@ type Connection struct {
UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"`
// Version of the connection
Version uint32 `json:"version,omitempty" gorm:"default:null"`
// Status is a field derived from connected_time_range
Status string `json:"status,omitempty" gorm:"default:null"`
tableName string `gorm:"-"`
}
@ -94,6 +96,7 @@ func (c *Connection) Clone() any {
BytesDown: c.BytesDown,
ClosedReason: c.ClosedReason,
Version: c.Version,
Status: c.Status,
}
if c.CreateTime != nil {
clone.CreateTime = &timestamp.Timestamp{

@ -4,20 +4,9 @@
package session
import (
"context"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"google.golang.org/protobuf/types/known/timestamppb"
workerpbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
)
const (
defaultConnectionStateTableName = "session_connection_state"
)
// ConnectionStatus of the connection's state
type ConnectionStatus string
@ -60,122 +49,14 @@ func ConnectionStatusFromProtoVal(s workerpbs.CONNECTIONSTATUS) ConnectionStatus
return StatusUnspecified
}
// ConnectionState of the state of the connection
type ConnectionState struct {
// ConnectionId is used to access the state via an API
ConnectionId string `json:"public_id,omitempty" gorm:"primary_key"`
// status of the connection
Status ConnectionStatus `protobuf:"bytes,20,opt,name=status,proto3" json:"status,omitempty" gorm:"column:state"`
// PreviousEndTime from the RDBMS
PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"`
// StartTime from the RDBMS
StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"`
// EndTime from the RDBMS
EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"`
tableName string `gorm:"-"`
}
var (
_ Cloneable = (*ConnectionState)(nil)
_ db.VetForWriter = (*ConnectionState)(nil)
)
// NewConnectionState creates a new in memory connection state. No options
// are currently supported.
func NewConnectionState(ctx context.Context, connectionId string, state ConnectionStatus, _ ...Option) (*ConnectionState, error) {
const op = "session.NewConnectionState"
s := ConnectionState{
ConnectionId: connectionId,
Status: state,
}
if err := s.validate(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return &s, nil
}
// allocConnectionState will allocate a connection State
func allocConnectionState() ConnectionState {
return ConnectionState{}
}
// Clone creates a clone of the State
func (s *ConnectionState) Clone() any {
clone := &ConnectionState{
ConnectionId: s.ConnectionId,
Status: s.Status,
}
if s.PreviousEndTime != nil {
clone.PreviousEndTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.PreviousEndTime.Timestamp.Seconds,
Nanos: s.PreviousEndTime.Timestamp.Nanos,
},
}
}
if s.StartTime != nil {
clone.StartTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.StartTime.Timestamp.Seconds,
Nanos: s.StartTime.Timestamp.Nanos,
},
}
}
if s.EndTime != nil {
clone.EndTime = &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: s.EndTime.Timestamp.Seconds,
Nanos: s.EndTime.Timestamp.Nanos,
},
}
}
return clone
}
// VetForWrite implements db.VetForWrite() interface and validates the state
// before it's written.
func (s *ConnectionState) VetForWrite(ctx context.Context, _ db.Reader, _ db.OpType, _ ...db.Option) error {
const op = "session.(ConnectionState).VetForWrite"
if err := s.validate(ctx); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
// TableName returns the tablename to override the default gorm table name
func (s *ConnectionState) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultConnectionStateTableName
}
// SetTableName sets the tablename and satisfies the ReplayableMessage
// interface. If the caller attempts to set the name to "" the name will be
// reset to the default name.
func (s *ConnectionState) SetTableName(n string) {
s.tableName = n
}
// validate checks the session state
func (s *ConnectionState) validate(ctx context.Context) error {
const op = "session.(ConnectionState).validate"
if s.Status == "" {
return errors.New(ctx, errors.InvalidParameter, op, "missing status")
}
if s.ConnectionId == "" {
return errors.New(ctx, errors.InvalidParameter, op, "missing connection id")
}
if s.StartTime != nil {
return errors.New(ctx, errors.InvalidParameter, op, "start time is not settable")
}
if s.EndTime != nil {
return errors.New(ctx, errors.InvalidParameter, op, "end time is not settable")
}
if s.PreviousEndTime != nil {
return errors.New(ctx, errors.InvalidParameter, op, "previous end time is not settable")
func ConnectionStatusFromString(s string) ConnectionStatus {
switch s {
case "authorized":
return StatusAuthorized
case "connected":
return StatusConnected
case "closed":
return StatusClosed
}
return nil
return StatusUnspecified
}

@ -1,217 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package session
import (
"context"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConnectionState_Create(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
session := TestDefaultSession(t, conn, wrapper, iamRepo)
connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443, "127.0.0.1")
type args struct {
connectionId string
status ConnectionStatus
}
tests := []struct {
name string
args args
want *ConnectionState
wantErr bool
wantIsErr errors.Code
create bool
wantCreateErr bool
}{
{
name: "valid",
args: args{
connectionId: connection.PublicId,
status: StatusClosed,
},
want: &ConnectionState{
ConnectionId: connection.PublicId,
Status: StatusClosed,
},
create: true,
},
{
name: "empty-connectionId",
args: args{
status: StatusClosed,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
{
name: "empty-status",
args: args{
connectionId: connection.PublicId,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := NewConnectionState(ctx, tt.args.connectionId, tt.args.status)
if tt.wantErr {
require.Error(err)
assert.True(errors.Match(errors.T(tt.wantIsErr), err))
return
}
require.NoError(err)
assert.Equal(tt.want, got)
if tt.create {
err = db.New(conn).Create(ctx, got)
if tt.wantCreateErr {
assert.Error(err)
return
} else {
assert.NoError(err)
}
}
})
}
}
func TestConnectionState_Delete(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
tests := []struct {
name string
state *ConnectionState
deleteConnectionStateId string
wantRowsDeleted int
wantErr bool
wantErrMsg string
}{
{
name: "valid",
state: TestConnectionState(t, conn, c.PublicId, StatusClosed),
wantErr: false,
wantRowsDeleted: 1,
},
{
name: "bad-id",
state: TestConnectionState(t, conn, c2.PublicId, StatusClosed),
deleteConnectionStateId: func() string {
id, err := db.NewPublicId(ctx, ConnectionStatePrefix)
require.NoError(t, err)
return id
}(),
wantErr: false,
wantRowsDeleted: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var initialState ConnectionState
err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", []any{tt.state.ConnectionId, tt.state.Status})
require.NoError(err)
deleteState := allocConnectionState()
if tt.deleteConnectionStateId != "" {
deleteState.ConnectionId = tt.deleteConnectionStateId
} else {
deleteState.ConnectionId = tt.state.ConnectionId
}
deleteState.StartTime = initialState.StartTime
deletedRows, err := rw.Delete(ctx, &deleteState)
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
if tt.wantRowsDeleted == 0 {
assert.Equal(tt.wantRowsDeleted, deletedRows)
return
}
assert.Equal(tt.wantRowsDeleted, deletedRows)
foundState := allocConnectionState()
err = rw.LookupWhere(ctx, &foundState, "connection_id = ? and start_time = ?", []any{tt.state.ConnectionId, initialState.StartTime})
require.Error(err)
assert.True(errors.IsNotFoundError(err))
})
}
}
func TestConnectionState_Clone(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
t.Run("valid", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
cp := state.Clone()
assert.Equal(cp.(*ConnectionState), state)
})
t.Run("not-equal", func(t *testing.T) {
assert := assert.New(t)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
state := TestConnectionState(t, conn, c.PublicId, StatusConnected)
state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected)
cp := state.Clone()
assert.NotEqual(cp.(*ConnectionState), state2)
})
}
func TestConnectionState_SetTableName(t *testing.T) {
t.Parallel()
defaultTableName := defaultConnectionStateTableName
tests := []struct {
name string
setNameTo string
want string
}{
{
name: "new-name",
setNameTo: "new-name",
want: "new-name",
},
{
name: "reset to default",
setNameTo: "",
want: defaultTableName,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
def := allocConnectionState()
require.Equal(defaultTableName, def.TableName())
s := allocConnectionState()
s.SetTableName(tt.setNameTo)
assert.Equal(tt.want, s.TableName())
})
}
}

@ -9,7 +9,6 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -236,96 +235,3 @@ func TestConnection_ImmutableFields(t *testing.T) {
})
}
}
func TestConnectionState_ImmutableFields(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
ts := timestamp.Timestamp{Timestamp: &timestamppb.Timestamp{Seconds: 0, Nanos: 0}}
_, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
session := TestDefaultSession(t, conn, wrapper, iamRepo)
connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
state := TestConnectionState(t, conn, connection.PublicId, StatusConnected)
var new ConnectionState
err := rw.LookupWhere(context.Background(), &new, "connection_id = ? and state = ?", []any{state.ConnectionId, state.Status})
require.NoError(t, err)
tests := []struct {
name string
update *ConnectionState
fieldMask []string
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "session_id",
update: func() *ConnectionState {
s := new.Clone().(*ConnectionState)
s.ConnectionId = "sc_thisIsNotAValidId"
return s
}(),
fieldMask: []string{"PublicId"},
},
{
name: "status",
update: func() *ConnectionState {
s := new.Clone().(*ConnectionState)
s.Status = "closed"
return s
}(),
fieldMask: []string{"Status"},
wantErrMatch: errors.T(errors.NotSpecificIntegrity),
wantErrContains: "immutable column",
},
{
name: "start time",
update: func() *ConnectionState {
s := new.Clone().(*ConnectionState)
s.StartTime = &ts
return s
}(),
fieldMask: []string{"StartTime"},
wantErrMatch: errors.T(errors.InvalidFieldMask),
wantErrContains: "parameter violation",
},
{
name: "previous_end_time",
update: func() *ConnectionState {
s := new.Clone().(*ConnectionState)
s.PreviousEndTime = &ts
return s
}(),
fieldMask: []string{"PreviousEndTime"},
wantErrMatch: errors.T(errors.NotSpecificIntegrity),
wantErrContains: "immutable column",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
orig := new.Clone()
err := rw.LookupWhere(context.Background(), orig, "connection_id = ? and start_time = ?", []any{new.ConnectionId, new.StartTime})
require.NoError(err)
rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true))
require.Error(err)
assert.Equal(0, rowsUpdated)
if tt.wantErrMatch != nil {
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted error %s and got: %s", tt.wantErrMatch.Code, err.Error())
}
if tt.wantErrContains != "" {
assert.Contains(err.Error(), tt.wantErrContains)
}
after := new.Clone()
err = rw.LookupWhere(context.Background(), after, "connection_id = ? and start_time = ?", []any{new.ConnectionId, new.StartTime})
require.NoError(err)
assert.Equal(orig.(*ConnectionState), after)
})
}
}

@ -65,10 +65,9 @@ func TestSessionConnectionCleanupJob(t *testing.T) {
sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true)))
sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo"))
require.NoError(err)
c, cs, _, err := AuthorizeConnection(ctx, sessionRepo, connectionRepo, sess.GetPublicId(), serverId)
c, _, err := AuthorizeConnection(ctx, sessionRepo, connectionRepo, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status))
connIds = append(connIds, c.GetPublicId())
if i%2 == 0 {
connIdsByWorker[worker2.PublicId] = append(connIdsByWorker[worker2.PublicId], c.GetPublicId())
@ -81,7 +80,7 @@ func TestSessionConnectionCleanupJob(t *testing.T) {
// This is just to ensure we have a spread when we test it out.
for i, connId := range connIds {
if i%2 == 0 {
_, cs, err := connectionRepo.ConnectConnection(ctx, ConnectWith{
cc, err := connectionRepo.ConnectConnection(ctx, ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
@ -90,18 +89,7 @@ func TestSessionConnectionCleanupJob(t *testing.T) {
UserClientIp: "127.0.0.1",
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status))
}
}
@ -130,14 +118,11 @@ func TestSessionConnectionCleanupJob(t *testing.T) {
require.True(ok)
require.Len(connIds, 6)
for _, connId := range connIds {
_, states, err := connectionRepo.LookupConnection(ctx, connId)
conn, err := connectionRepo.LookupConnection(ctx, connId)
require.NoError(err)
var foundClosed bool
for _, state := range states {
if state.Status == StatusClosed {
foundClosed = true
break
}
if ConnectionStatusFromString(conn.Status) == StatusClosed {
foundClosed = true
}
assert.Equal(closed, foundClosed)
}
@ -239,10 +224,9 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) {
sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true)))
sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo"))
require.NoError(err)
c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status))
if i%3 == 0 {
worker1ConnIds = append(worker1ConnIds, c.GetPublicId())
} else if i%3 == 1 {
@ -263,7 +247,7 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) {
return s
}() {
if i%3 == 0 {
_, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{
cc, err := connRepo.ConnectConnection(ctx, ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
@ -272,18 +256,7 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) {
UserClientIp: "127.0.0.1",
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status))
} else if i%3 == 1 {
resp, err := connRepo.closeConnections(ctx, []CloseWith{
{
@ -293,19 +266,8 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) {
})
require.NoError(err)
require.Len(resp, 1)
cs := resp[0].ConnectionStates
require.Len(cs, 2)
var foundAuthorized, foundClosed bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusClosed {
foundClosed = true
}
}
require.True(foundAuthorized)
require.True(foundClosed)
cs := resp[0].ConnectionState
require.Equal(StatusClosed, cs)
}
}
@ -344,9 +306,9 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) {
expected = StatusAuthorized
}
_, states, err := connRepo.LookupConnection(ctx, connId)
conn, err := connRepo.LookupConnection(ctx, connId)
require.NoError(err)
require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected)
require.Equal(expected, ConnectionStatusFromString(conn.Status), "expected latest status for %q (index %d) to be %v", connId, i, expected)
}
}
@ -480,10 +442,9 @@ func TestCloseWorkerlessConnections(t *testing.T) {
sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo"))
require.NoError(err)
conn, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), workerId)
conn, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), workerId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(conn.Status))
return conn
}
@ -513,21 +474,21 @@ func TestCloseWorkerlessConnections(t *testing.T) {
}})
require.NoError(err)
_, st, err := connRepo.LookupConnection(ctx, dActiveConn.GetPublicId())
con, err := connRepo.LookupConnection(ctx, dActiveConn.GetPublicId())
require.NoError(err)
require.Equal(StatusAuthorized, st[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId())
require.NoError(err)
require.Equal(StatusClosed, st[0].Status)
require.Equal(StatusClosed, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId())
require.NoError(err)
require.Equal(StatusAuthorized, st[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId())
require.NoError(err)
require.Equal(StatusClosed, st[0].Status)
require.Equal(StatusClosed, ConnectionStatusFromString(con.Status))
// Run the job
numClosed, err := job.closeWorkerlessConnections(ctx)
@ -535,19 +496,19 @@ func TestCloseWorkerlessConnections(t *testing.T) {
assert.Equal(t, 1, numClosed)
// This is the only one that the job should have actually closed.
_, st, err = connRepo.LookupConnection(ctx, dActiveConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, dActiveConn.GetPublicId())
require.NoError(err)
require.Equal(StatusClosed, st[0].Status)
require.Equal(StatusClosed, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId())
require.NoError(err)
require.Equal(StatusClosed, st[0].Status)
require.Equal(StatusClosed, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId())
require.NoError(err)
require.Equal(StatusAuthorized, st[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status))
_, st, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId())
con, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId())
require.NoError(err)
require.Equal(StatusClosed, st[0].Status)
require.Equal(StatusClosed, ConnectionStatusFromString(con.Status))
}

@ -124,6 +124,14 @@ from
session_connection_limit, session_connection_count;
`
// connectConnection sets the connected time range to (now, infinity) to
// indicate the connection is connected.
connectConnection = `
update session_connection
set connected_time_range=tstzrange(now(),'infinity')
where public_id=@public_id
`
terminateSessionIfPossible = `
-- is terminate_session_id in a canceling state
with session_version as (
@ -197,17 +205,11 @@ from
select
session_id
from
session_connection
where public_id in (
select
connection_id
from
session_connection_state
where
state != 'closed' and
end_time is null
)
)
session_connection
where
upper(connected_time_range) > now() or
connected_time_range is null
)
`
// termSessionUpdate is one stmt that terminates sessions for the following
@ -271,21 +273,12 @@ where
)
) and
-- make sure there are no existing connections
us.public_id not in (
select
session_id
from
session_connection
where public_id in (
select
connection_id
from
session_connection_state
where
state != 'closed' and
end_time is null
)
);
us.public_id not in (
select session_id
from session_connection
where upper(connected_time_range) > now()
or connected_time_range is null
);
`
// closeConnectionsForDeadServersCte finds connections that are:
@ -336,22 +329,20 @@ where
and closed_reason is null
returning public_id;
`
orphanedConnectionsCte = `
-- Find connections that are not closed so we can reference those IDs
with
unclosed_connections as (
select connection_id
from session_connection_state
select public_id
from session_connection
where
-- It's the current state
end_time is null
-- Current state isn't closed state
and state in ('authorized', 'connected')
-- It's not closed
upper(connected_time_range) > now() or
connected_time_range is null
-- It's not in limbo between when it moved into this state and when
-- it started being reported by the worker, which is roughly every
-- 2-3 seconds
and start_time < wt_sub_seconds_from_now(@worker_state_delay_seconds)
and update_time < wt_sub_seconds_from_now(@worker_state_delay_seconds)
),
connections_to_close as (
select public_id
@ -360,7 +351,7 @@ with
-- Related to the worker that just reported to us
worker_id = @worker_id
-- Only unclosed ones
and public_id in (select connection_id from unclosed_connections)
and public_id in (select public_id from unclosed_connections)
-- These are connection IDs that just got reported to us by the given
-- worker, so they should not be considered closed.
%s

@ -134,19 +134,18 @@ func (r *ConnectionRepository) updateBytesUpBytesDown(ctx context.Context, conns
// If authorization is success, it creates/stores a new connection in the repo
// and returns it, along with its states. If the authorization fails, it
// an error with Code InvalidSessionState.
func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, []*ConnectionState, error) {
func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, error) {
const op = "session.(ConnectionRepository).AuthorizeConnection"
if sessionId == "" {
return nil, nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter))
return nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter))
}
connectionId, err := newConnectionId(ctx)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
connection := AllocConnection()
connection.PublicId = connectionId
var connectionStates []*ConnectionState
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
@ -166,31 +165,26 @@ func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionI
if err := reader.LookupById(ctx, &connection); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for session %s", sessionId)))
}
connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc"))
if err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
},
)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
return &connection, connectionStates, nil
return &connection, nil
}
// LookupConnection will look up a connection in the repository and return the connection
// with its states. If the connection is not found, it will return nil, nil, nil.
// with its state. If the connection is not found, it will return nil, nil.
// No options are currently supported.
func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, []*ConnectionState, error) {
func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, error) {
const op = "session.(ConnectionRepository).LookupConnection"
if connectionId == "" {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing connectionId id")
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing connectionId id")
}
connection := AllocConnection()
connection.PublicId = connectionId
var states []*ConnectionState
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
@ -199,20 +193,16 @@ func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionI
if err := read.LookupById(ctx, &connection); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", connectionId)))
}
var err error
if states, err = fetchConnectionStates(ctx, read, connectionId, db.WithOrder("start_time desc")); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
},
)
if err != nil {
if errors.IsNotFoundError(err) {
return nil, nil, nil
return nil, nil
}
return nil, nil, errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
return &connection, states, nil
return &connection, nil
}
// ListConnectionsBySessionId will list connections by session ID. Supports the
@ -231,14 +221,13 @@ func (r *ConnectionRepository) ListConnectionsBySessionId(ctx context.Context, s
}
// ConnectConnection updates a connection in the repo with a state of "connected".
func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, error) {
const op = "session.(ConnectionRepository).ConnectConnection"
// ConnectWith.validate will check all the fields...
if err := c.validate(ctx); err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
var connection Connection
var connectionStates []*ConnectionState
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
@ -266,31 +255,33 @@ func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectW
// return err, which will result in a rollback of the update
return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated")
}
newState, err := NewConnectionState(ctx, connection.PublicId, StatusConnected)
// Set the lower bound of the connected_time_range to indicate the connection is connected
rowsUpdated, err = w.Exec(ctx, connectConnection, []any{
sql.Named("public_id", c.ConnectionId),
})
if err != nil {
return errors.Wrap(ctx, err, op)
}
if err := w.Create(ctx, newState); err != nil {
return errors.Wrap(ctx, err, op)
if rowsUpdated != 1 {
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to connect connection %s", c.ConnectionId)))
}
connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc"))
if err != nil {
return errors.Wrap(ctx, err, op)
if err := reader.LookupById(ctx, &connection); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for connection %s", c.ConnectionId)))
}
return nil
},
)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
return nil, errors.Wrap(ctx, err, op)
}
return &connection, connectionStates, nil
return &connection, nil
}
// closeConnectionResp is just a wrapper for the response from CloseConnections.
// It wraps the connection and its states for each connection closed.
// It wraps the connection and its state for each connection closed.
type closeConnectionResp struct {
Connection *Connection
ConnectionStates []*ConnectionState
Connection *Connection
ConnectionState ConnectionStatus
}
// closeConnections set's a connection's state to "closed" in the repo. It's
@ -318,8 +309,8 @@ func (r *ConnectionRepository) closeConnections(ctx context.Context, closeWith [
updateConnection.BytesUp = cw.BytesUp
updateConnection.BytesDown = cw.BytesDown
updateConnection.ClosedReason = cw.ClosedReason.String()
// updating the ClosedReason will trigger an insert into the
// session_connection_state with a state of closed.
// updating the ClosedReason will trigger the session_connection to set the
// upper limit of connection_time_range to indicate the connection is closed.
rowsUpdated, err := w.Update(
ctx,
&updateConnection,
@ -332,13 +323,9 @@ func (r *ConnectionRepository) closeConnections(ctx context.Context, closeWith [
if rowsUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("%d would have been updated for connection %s", rowsUpdated, cw.ConnectionId))
}
states, err := fetchConnectionStates(ctx, reader, cw.ConnectionId, db.WithOrder("start_time desc"))
if err != nil {
return errors.Wrap(ctx, err, op)
}
resp = append(resp, closeConnectionResp{
Connection: &updateConnection,
ConnectionStates: states,
Connection: &updateConnection,
ConnectionState: StatusClosed,
})
}
@ -441,15 +428,3 @@ func (r *ConnectionRepository) closeOrphanedConnections(ctx context.Context, wor
}
return orphanedConns, nil
}
func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) {
const op = "session.fetchConnectionStates"
var states []*ConnectionState
if err := r.SearchWhere(ctx, &states, "connection_id = ?", []any{connectionId}, opt...); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
if len(states) == 0 {
return nil, nil
}
return states, nil
}

@ -240,7 +240,7 @@ func TestRepository_ConnectConnection(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
c, cs, err := connRepo.ConnectConnection(context.Background(), tt.connectWith)
c, err := connRepo.ConnectConnection(context.Background(), tt.connectWith)
if tt.wantErr {
require.Error(err)
assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error())
@ -248,9 +248,8 @@ func TestRepository_ConnectConnection(t *testing.T) {
}
require.NoError(err)
require.NotNil(c)
require.NotNil(cs)
assert.Equal(StatusConnected, cs[0].Status)
gotConn, _, err := connRepo.LookupConnection(context.Background(), c.PublicId)
assert.Equal(StatusConnected, ConnectionStatusFromString(c.Status))
gotConn, err := connRepo.LookupConnection(context.Background(), c.PublicId)
require.NoError(err)
assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress)
assert.Equal(tt.connectWith.ClientTcpPort, gotConn.ClientTcpPort)
@ -336,7 +335,7 @@ func TestRepository_DeleteConnection(t *testing.T) {
}
assert.NoError(err)
assert.Equal(tt.wantRowsDeleted, deletedRows)
found, _, err := connRepo.LookupConnection(context.Background(), tt.args.connection.PublicId)
found, err := connRepo.LookupConnection(context.Background(), tt.args.connection.PublicId)
assert.NoError(err)
assert.Nil(found)
@ -378,10 +377,9 @@ func TestRepository_orphanedConnections(t *testing.T) {
sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true)))
sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo"))
require.NoError(err)
c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status))
connIds = append(connIds, c.GetPublicId())
if i%2 == 0 {
worker2ConnIds = append(worker2ConnIds, c.GetPublicId())
@ -394,7 +392,7 @@ func TestRepository_orphanedConnections(t *testing.T) {
// This is just to ensure we have a spread when we test it out.
for i, connId := range connIds {
if i%2 == 0 {
_, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{
cc, err := connRepo.ConnectConnection(ctx, ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
@ -403,18 +401,7 @@ func TestRepository_orphanedConnections(t *testing.T) {
UserClientIp: "127.0.0.1",
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status))
}
}
@ -517,8 +504,8 @@ func TestRepository_CloseConnections(t *testing.T) {
assert.Equal(len(tt.closeWith), len(resp))
for _, r := range resp {
require.NotNil(r.Connection)
require.NotNil(r.ConnectionStates)
assert.Equal(StatusClosed, r.ConnectionStates[0].Status)
require.NotNil(r.ConnectionState)
assert.Equal(StatusClosed, r.ConnectionState)
}
})
}
@ -561,7 +548,7 @@ func TestUpdateBytesUpDown(t *testing.T) {
// Assert that the bytes up and down values have been persisted.
for i := 0; i < len(conns); i++ {
c, _, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId())
c, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId())
require.NoError(t, err)
require.Equal(t, conns[i].BytesUp, c.BytesUp)
@ -604,10 +591,105 @@ func TestUpdateBytesUpDown(t *testing.T) {
// BytesUp and BytesDown values should be set to the old ones.
for i := 0; i < len(conns); i++ {
c, _, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId())
c, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId())
require.NoError(t, err)
require.Equal(t, conns[i].BytesUp, c.BytesUp)
require.Equal(t, conns[i].BytesDown, c.BytesDown)
}
}
func TestRepository_StateTransitions(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(ctx, rw, rw, kms)
require.NoError(t, err)
connRepo, err := NewConnectionRepository(ctx, rw, rw, kms)
require.NoError(t, err)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
tofu := TestTofu(t)
_, _, err = repo.ActivateSession(context.Background(), s.PublicId, s.Version, tofu)
require.NoError(t, err)
// First connection will transition authorized -> connected -> closed
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
cw := ConnectWith{
ConnectionId: c.PublicId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
EndpointTcpPort: 2222,
UserClientIp: "127.0.0.1",
}
gotConn, err := connRepo.LookupConnection(context.Background(), c.PublicId)
require.NoError(t, err)
require.NotNil(t, gotConn)
require.Equal(t, StatusAuthorized, ConnectionStatusFromString(gotConn.Status))
_, err = connRepo.ConnectConnection(context.Background(), cw)
require.NoError(t, err)
gotConn, err = connRepo.LookupConnection(context.Background(), c.PublicId)
require.NoError(t, err)
require.NotNil(t, gotConn)
require.Equal(t, StatusConnected, ConnectionStatusFromString(gotConn.Status))
// Attempt to connect again, expect failure
_, err = connRepo.ConnectConnection(context.Background(), cw)
require.Error(t, err)
require.Contains(t, err.Error(), "Invalid state transition from connected")
closeWith := CloseWith{
ConnectionId: c.PublicId,
ClosedReason: ConnectionClosedByUser,
}
resp, err := connRepo.closeConnections(context.Background(), []CloseWith{closeWith})
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, StatusClosed, resp[0].ConnectionState)
// Second connection will transition from authorized -> closed
c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1")
gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId)
require.NoError(t, err)
require.NotNil(t, gotConn)
require.Equal(t, StatusAuthorized, ConnectionStatusFromString(gotConn.Status))
closeWith2 := CloseWith{
ConnectionId: c2.PublicId,
ClosedReason: ConnectionClosedByUser,
}
resp, err = connRepo.closeConnections(context.Background(), []CloseWith{closeWith2})
require.NoError(t, err)
require.NotNil(t, resp)
require.Equal(t, StatusClosed, resp[0].ConnectionState)
gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId)
require.NoError(t, err)
require.NotNil(t, gotConn)
require.Equal(t, StatusClosed, ConnectionStatusFromString(gotConn.Status))
// Now try to connect it while closed and ensure it can't transition to connected
cw2 := ConnectWith{
ConnectionId: c2.PublicId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
EndpointTcpPort: 2222,
UserClientIp: "127.0.0.1",
}
_, err = connRepo.ConnectConnection(context.Background(), cw2)
require.Error(t, err)
require.Contains(t, err.Error(), "Invalid state transition from closed")
gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId)
require.NoError(t, err)
require.NotNil(t, gotConn)
require.Equal(t, StatusClosed, ConnectionStatusFromString(gotConn.Status))
}

@ -1149,12 +1149,10 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) {
conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId)
require.NoError(err)
for _, sc := range conn {
c, cs, err := connRepo.LookupConnection(context.Background(), sc.PublicId)
c, err := connRepo.LookupConnection(context.Background(), sc.PublicId)
require.NoError(err)
assert.NotEmpty(c.ClosedReason)
for _, s := range cs {
t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime)
}
t.Logf("%s session connection state %s", found.PublicId, c.Status)
}
} else {
t.Logf("not terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit)
@ -1162,11 +1160,9 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) {
conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId)
require.NoError(err)
for _, sc := range conn {
cs, err := fetchConnectionStates(context.Background(), rw, sc.PublicId)
c, err := connRepo.LookupConnection(context.Background(), sc.PublicId)
require.NoError(err)
for _, s := range cs {
t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime)
}
t.Logf("%s session connection state %s", found.PublicId, c.Status)
}
}
}

@ -17,17 +17,17 @@ import (
// If any of these criteria is not met, it returns an error with Code InvalidSessionState.
func AuthorizeConnection(ctx context.Context, sessionRepoFn *Repository, connectionRepoFn *ConnectionRepository,
sessionId, workerId string, opt ...Option,
) (*Connection, []*ConnectionState, *AuthzSummary, error) {
) (*Connection, *AuthzSummary, error) {
const op = "session.AuthorizeConnection"
connection, connectionStates, err := connectionRepoFn.AuthorizeConnection(ctx, sessionId, workerId)
connection, err := connectionRepoFn.AuthorizeConnection(ctx, sessionId, workerId)
if err != nil {
return nil, nil, nil, errors.Wrap(ctx, err, op)
return nil, nil, errors.Wrap(ctx, err, op)
}
authzSummary, err := sessionRepoFn.sessionAuthzSummary(ctx, sessionId)
if err != nil {
return nil, nil, nil, errors.Wrap(ctx, err, op)
return nil, nil, errors.Wrap(ctx, err, op)
}
return connection, connectionStates, authzSummary, nil
return connection, authzSummary, nil
}

@ -138,7 +138,7 @@ func TestService_AuthorizeConnection(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
c, cs, authzInfo, err := AuthorizeConnection(context.Background(), repo, connRepo, tt.session.PublicId, testServer)
c, authzInfo, err := AuthorizeConnection(context.Background(), repo, connRepo, tt.session.PublicId, testServer)
if tt.wantErr {
require.Error(err)
// TODO (jimlambrt 9/2020): add in tests for errorsIs once we
@ -150,8 +150,8 @@ func TestService_AuthorizeConnection(t *testing.T) {
}
require.NoError(err)
require.NotNil(c)
require.NotNil(cs)
assert.Equal(StatusAuthorized, cs[0].Status)
require.NotNil(c.Status)
assert.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status))
assert.True(authzInfo.ExpirationTime.GetTimestamp().AsTime().Sub(tt.wantAuthzInfo.ExpirationTime.GetTimestamp().AsTime()) < 10*time.Millisecond)
tt.wantAuthzInfo.ExpirationTime = authzInfo.ExpirationTime

@ -131,8 +131,8 @@ func TestServiceCloseConnections(t *testing.T) {
for _, r := range resp {
require.NotNil(r.Connection)
require.NotNil(r.ConnectionStates)
assert.Equal(StatusClosed, r.ConnectionStates[0].Status)
require.NotNil(r.ConnectionState)
assert.Equal(StatusClosed, r.ConnectionState)
}
// Ensure session is in the state we want- terminated if all conns closed, else active

@ -94,7 +94,7 @@ func TestWorkerStatusReport(t *testing.T) {
require.NoError(t, err)
require.NoError(t, err)
_, _, err = connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
_, err = connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess.PublicId, sess.Version)
@ -126,7 +126,7 @@ func TestWorkerStatusReport(t *testing.T) {
require.NoError(t, err)
require.NoError(t, err)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
return testCase{
worker: worker,
@ -160,7 +160,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu := session.TestTofu(t)
sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu)
require.NoError(t, err)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess.PublicId, sess.Version)
require.NoError(t, err)
@ -224,7 +224,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu := session.TestTofu(t)
sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu)
require.NoError(t, err)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess.PublicId, sess.Version)
require.NoError(t, err)
@ -242,7 +242,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu2 := session.TestTofu(t)
sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2)
require.NoError(t, err)
connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version)
require.NoError(t, err)
@ -295,7 +295,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu := session.TestTofu(t)
sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu)
require.NoError(t, err)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{
@ -311,7 +311,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu2 := session.TestTofu(t)
sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2)
require.NoError(t, err)
connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
require.NoError(t, err)
require.NotEqual(t, connection.PublicId, connection2.PublicId)
@ -348,7 +348,7 @@ func TestWorkerStatusReport(t *testing.T) {
tofu := session.TestTofu(t)
sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu)
require.NoError(t, err)
connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess.PublicId, sess.Version)
require.NoError(t, err)
@ -366,9 +366,9 @@ func TestWorkerStatusReport(t *testing.T) {
tofu2 := session.TestTofu(t)
sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2)
require.NoError(t, err)
connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
require.NoError(t, err)
connection3, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
connection3, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId)
require.NoError(t, err)
_, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version)
require.NoError(t, err)
@ -417,12 +417,10 @@ func TestWorkerStatusReport(t *testing.T) {
require.NoError(err)
assert.ElementsMatch(tc.want, got)
for _, dc := range tc.orphanedConnections {
gotConn, states, err := connRepo.LookupConnection(ctx, dc)
gotConn, err := connRepo.LookupConnection(ctx, dc)
require.NoError(err)
assert.Equal(session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason))
assert.Equal(2, len(states))
assert.Nil(states[0].EndTime)
assert.Equal(session.StatusClosed, states[0].Status)
assert.Equal(session.StatusClosed, session.ConnectionStatusFromString(gotConn.Status))
}
})
}

@ -41,26 +41,9 @@ func TestConnection(t testing.TB, conn *db.DB, sessionId, clientTcpAddr string,
err = rw.Create(ctx, c)
require.NoError(err)
connectedState, err := NewConnectionState(ctx, c.PublicId, StatusConnected)
require.NoError(err)
err = rw.Create(ctx, connectedState)
require.NoError(err)
return c
}
// TestConnectionState creates a test connection state for the connectionId in the repository.
func TestConnectionState(t testing.TB, conn *db.DB, connectionId string, state ConnectionStatus) *ConnectionState {
t.Helper()
ctx := context.Background()
require := require.New(t)
rw := db.New(conn)
s, err := NewConnectionState(ctx, connectionId, state)
require.NoError(err)
err = rw.Create(context.Background(), s)
require.NoError(err)
return s
}
// TestState creates a test state for the sessionId in the repository.
func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State {
t.Helper()

@ -50,29 +50,6 @@ func Test_TestConnection(t *testing.T) {
require.NotNil(c)
}
func Test_TestConnectionState(t *testing.T) {
assert, require := assert.New(t), require.New(t)
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
s := TestDefaultSession(t, conn, wrapper, iamRepo)
require.NotNil(s)
assert.NotEmpty(s.PublicId)
c := TestConnection(t, conn, s.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222, "127.0.0.1")
require.NotNil(c)
assert.NotEmpty(c.PublicId)
cs := TestConnectionState(t, conn, c.PublicId, StatusClosed)
require.NotNil(cs)
rw := db.New(conn)
var initialState ConnectionState
err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", []any{cs.ConnectionId, cs.Status})
require.NoError(err)
assert.NotEmpty(initialState.StartTime)
}
func Test_TestWorker(t *testing.T) {
require := require.New(t)
conn, _ := db.TestSetup(t, "postgres")

@ -155,11 +155,9 @@ func (s *TestSession) ExpectConnectionStateOnController(
}
for i, conn := range conns {
_, states, err := connectionRepo.LookupConnection(ctx, conn.PublicId, nil)
c, err := connectionRepo.LookupConnection(ctx, conn.PublicId, nil)
require.NoError(err)
// Look at the first state in the returned list, which will
// be the most recent state.
actualStates[i] = states[0].Status
actualStates[i] = session.ConnectionStatusFromString(c.Status)
}
if reflect.DeepEqual(expectStates, actualStates) {

Loading…
Cancel
Save