From 2191aa1d034dd7c9acd784eeefb078eb550b6802 Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Tue, 23 Jul 2024 14:19:21 -0400 Subject: [PATCH] refact(session connection): remove session connection state table (#4617) * refact(session connection): remove session connection state table --- .../daemon/cluster/handlers/worker_service.go | 16 +- .../handlers/worker_service_status_test.go | 18 +- .../oss/postgres/0/50_session.up.sql | 1 + .../oss/postgres/0/51_connection.up.sql | 1 + .../15/01_wh_rename_key_columns.up.sql | 1 + .../27/01_disable_terminate_session.up.sql | 1 + .../postgres/27/02_wh_session_facts.up.sql | 1 + .../01_remove_session_connection_state.up.sql | 247 ++++++++++++++++++ .../session_connection_state_transition.sql | 35 +++ .../tests/wh/session_connection/update.sql | 3 +- internal/server/repository_worker_test.go | 6 +- internal/session/connection.go | 5 +- internal/session/connection_state.go | 137 +--------- internal/session/connection_state_test.go | 217 --------------- internal/session/immutable_fields_test.go | 94 ------- internal/session/job_session_cleanup_test.go | 105 +++----- internal/session/query.go | 61 ++--- internal/session/repository_connection.go | 85 +++--- .../session/repository_connection_test.go | 132 ++++++++-- internal/session/repository_session_test.go | 12 +- .../session/service_authorize_connection.go | 10 +- .../service_authorize_connection_test.go | 6 +- .../session/service_close_connections_test.go | 4 +- .../service_worker_status_report_test.go | 26 +- internal/session/testing.go | 17 -- internal/session/testing_test.go | 23 -- internal/tests/helper/testing_helper.go | 6 +- 27 files changed, 542 insertions(+), 728 deletions(-) create mode 100644 internal/db/schema/migrations/oss/postgres/89/01_remove_session_connection_state.up.sql create mode 100644 internal/db/sqltest/tests/session/session_connection_state_transition.sql delete mode 100644 internal/session/connection_state_test.go diff --git a/internal/daemon/cluster/handlers/worker_service.go b/internal/daemon/cluster/handlers/worker_service.go index 528cf341f9..2e3209f2a7 100644 --- a/internal/daemon/cluster/handlers/worker_service.go +++ b/internal/daemon/cluster/handlers/worker_service.go @@ -622,16 +622,13 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs return nil, status.Errorf(codes.NotFound, "worker not found with name %q", req.GetWorkerId()) } - connectionInfo, connStates, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId()) + connectionInfo, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId()) if err != nil { return nil, err } if connectionInfo == nil { return nil, status.Error(codes.Internal, "Invalid authorize connection response.") } - if len(connStates) == 0 { - return nil, status.Error(codes.Internal, "Invalid connection state in authorize response.") - } sessInfo, authzSummary, err := sessionRepo.LookupSession(ctx, req.GetSessionId()) if err != nil { @@ -648,7 +645,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs ret := &pbs.AuthorizeConnectionResponse{ ConnectionId: connectionInfo.GetPublicId(), - Status: connStates[0].Status.ProtoVal(), + Status: session.ConnectionStatusFromString(connectionInfo.Status).ProtoVal(), ConnectionsLeft: authzSummary.ConnectionLimit, Route: route, } @@ -680,7 +677,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err) } - connectionInfo, connStates, err := connRepo.ConnectConnection(ctx, session.ConnectWith{ + connectionInfo, err := connRepo.ConnectConnection(ctx, session.ConnectWith{ ConnectionId: req.GetConnectionId(), ClientTcpAddress: req.GetClientTcpAddress(), ClientTcpPort: req.GetClientTcpPort(), @@ -696,7 +693,7 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.C } return &pbs.ConnectConnectionResponse{ - Status: connStates[0].Status.ProtoVal(), + Status: session.ConnectionStatusFromString(connectionInfo.Status).ProtoVal(), }, nil } @@ -742,12 +739,9 @@ func (ws *workerServiceServer) CloseConnection(ctx context.Context, req *pbs.Clo if v.Connection == nil { return nil, status.Errorf(codes.Internal, "No connection found while closing one of the connection IDs: %v", closeIds) } - if len(v.ConnectionStates) == 0 { - return nil, status.Errorf(codes.Internal, "No connection states found while closing one of the connection IDs: %v", closeIds) - } closeData = append(closeData, &pbs.CloseConnectionResponseData{ ConnectionId: v.Connection.GetPublicId(), - Status: v.ConnectionStates[0].Status.ProtoVal(), + Status: v.ConnectionState.ProtoVal(), }) } diff --git a/internal/daemon/cluster/handlers/worker_service_status_test.go b/internal/daemon/cluster/handlers/worker_service_status_test.go index fdda29fbf9..1183372751 100644 --- a/internal/daemon/cluster/handlers/worker_service_status_test.go +++ b/internal/daemon/cluster/handlers/worker_service_status_test.go @@ -97,7 +97,7 @@ func TestStatus(t *testing.T) { tofu := session.TestTofu(t) canceledSess, _, err = repo.ActivateSession(ctx, canceledSess.PublicId, canceledSess.Version, tofu) require.NoError(t, err) - canceledConn, _, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId) + canceledConn, err := connRepo.AuthorizeConnection(ctx, canceledSess.PublicId, worker1.PublicId) require.NoError(t, err) canceledSess, err = repo.CancelSession(ctx, canceledSess.PublicId, canceledSess.Version) @@ -120,7 +120,7 @@ func TestStatus(t *testing.T) { s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) require.NoError(t, err) cases := []struct { @@ -562,7 +562,7 @@ func TestStatusSessionClosed(t *testing.T) { s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) require.NoError(t, err) cases := []struct { @@ -757,9 +757,9 @@ func TestStatusDeadConnection(t *testing.T) { s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) require.NoError(t, err) - deadConn, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId) + deadConn, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PublicId) require.NoError(t, err) require.NotEqual(t, deadConn.PublicId, connection.PublicId) @@ -823,12 +823,10 @@ func TestStatusDeadConnection(t *testing.T) { ), ) - gotConn, states, err := connRepo.LookupConnection(ctx, deadConn.PublicId) + gotConn, err := connRepo.LookupConnection(ctx, deadConn.PublicId) require.NoError(t, err) assert.Equal(t, session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason)) - assert.Equal(t, 2, len(states)) - assert.Nil(t, states[0].EndTime) - assert.Equal(t, session.StatusClosed, states[0].Status) + assert.Equal(t, session.StatusClosed, session.ConnectionStatusFromString(gotConn.Status)) } func TestStatusWorkerWithKeyId(t *testing.T) { @@ -927,7 +925,7 @@ func TestStatusWorkerWithKeyId(t *testing.T) { s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) require.NoError(t, err) cases := []struct { diff --git a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql index 015ffe9043..62db5118df 100644 --- a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql +++ b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql @@ -219,6 +219,7 @@ begin; create trigger insert_new_session_state after insert on session for each row execute procedure insert_new_session_state(); + -- Updated in 90/01_remove_session_connection_state -- update_connection_state_on_closed_reason() is used in an update insert trigger on the -- session_connection table. it will valiadate that all the session's -- connections are closed, and then insert a state of "closed" in diff --git a/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql b/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql index e8039e4a4d..bd067a7a4d 100644 --- a/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql +++ b/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql @@ -148,6 +148,7 @@ begin; create trigger default_create_time_column before insert on session_connection for each row execute procedure default_create_time(); + -- Removed in 90/01_remove_session_connection_state.up.sql -- insert_new_connection_state() is used in an after insert trigger on the -- session_connection table. it will insert a state of "authorized" in -- session_connection_state for the new session connection. diff --git a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql index 9d6a797e1b..16b9a8e67f 100644 --- a/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql +++ b/internal/db/schema/migrations/oss/postgres/15/01_wh_rename_key_columns.up.sql @@ -404,6 +404,7 @@ begin; drop trigger wh_insert_session_connection_state on session_connection_state; drop function wh_insert_session_connection_state; +-- Updated in 90/01_remove_session_connection_state.up.sql create function wh_insert_session_connection_state() returns trigger as $$ declare diff --git a/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql b/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql index f4a60ca424..af0d1451cc 100644 --- a/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql +++ b/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql @@ -7,6 +7,7 @@ begin; drop trigger update_connection_state_on_closed_reason on session_connection; drop function update_connection_state_on_closed_reason(); +-- Removed in 90/01_remove_session_connection_state.up.sql create function update_connection_state_on_closed_reason() returns trigger as $$ begin diff --git a/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql b/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql index 324b2e7a0b..98030c45d2 100644 --- a/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql +++ b/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql @@ -7,6 +7,7 @@ begin; drop trigger wh_insert_session_connection on session_connection; drop function wh_insert_session_connection(); +-- Updated in 90/01_remove_session_connection_state create function wh_insert_session_connection() returns trigger as $$ declare diff --git a/internal/db/schema/migrations/oss/postgres/89/01_remove_session_connection_state.up.sql b/internal/db/schema/migrations/oss/postgres/89/01_remove_session_connection_state.up.sql new file mode 100644 index 0000000000..ea4a949632 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/89/01_remove_session_connection_state.up.sql @@ -0,0 +1,247 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + + -- Remove the session_connection_state table and any related triggers + drop trigger update_connection_state_on_closed_reason on session_connection; + drop function update_connection_state_on_closed_reason(); + + drop trigger insert_session_connection_state on session_connection_state; + drop function insert_session_connection_state(); + + drop trigger update_session_state_on_termination_reason on session; + drop function update_session_state_on_termination_reason(); + + drop trigger insert_new_connection_state on session_connection; + drop function insert_new_connection_state(); + + drop trigger immutable_columns on session_connection_state; + + drop trigger wh_insert_session_connection_state on session_connection_state; + drop function wh_insert_session_connection_state(); + + drop trigger wh_insert_session_connection on session_connection; + drop function wh_insert_session_connection(); + + -- If the connected_time_range is null, it means the connection is authorized but not connected. + -- If the upper value of connected_time_range is > now() (upper range is infinity) then the state is connected. + -- If the upper value of connected_time_range is <= now() then the connection is closed. + alter table session_connection + add column connected_time_range tstzrange; + + -- Migrate existing data from session_connection_state to session_connection + update session_connection + set connected_time_range = (select tstzrange(min(start_time), max(start_time)) + from session_connection_state + where session_connection_state.connection_id = session_connection.public_id + group by connection_id ); + + drop table session_connection_state; + drop table session_connection_state_enm; + + -- Insert on session_connection creates the connection entry, leaving the connected_time_range to null, indicating the connection is authorized + -- "Connected" is handled by the function ConnectConnection, which sets the connected_time_range lower bound to now() and upper bound to infinity + -- "Closed" is handled by the trigger function, update_connected_time_range_on_closed_reason, which sets the connected_time_range upper bound to now() + -- State transitions are guarded by the trigger function, check_connection_state_transition, which ensures that the state transitions are valid + create function check_connection_state_transition() returns trigger + as $$ + begin + -- If old state was authorized, allow transition to connected or closed + if old.connected_time_range is null then + return new; + end if; + + -- If old state was closed, no transitions are allowed + if upper(old.connected_time_range) < 'infinity' and old.connected_time_range != new.connected_time_range then + raise exception 'Invalid state transition from closed'; + end if; + + -- If old state was connected, allow transition to closed + if upper(old.connected_time_range) = 'infinity' and + upper(new.connected_time_range) != 'infinity' and + lower(old.connected_time_range) = lower(new.connected_time_range) then + return new; + else + raise exception 'Invalid state transition from connected'; + end if; + + return new; + end; + $$ language plpgsql; + + create trigger check_connection_state_transition before update of connected_time_range on session_connection + for each row execute procedure check_connection_state_transition(); + + create function update_connected_time_range_on_closed_reason() returns trigger + as $$ + begin + if new.closed_reason is not null then + if old.connected_time_range is null or upper(old.connected_time_range) = 'infinity'::timestamptz then + new.connected_time_range = tstzrange(lower(old.connected_time_range), now(), '[]'); + end if; + end if; + return new; + end; + $$ language plpgsql; + + create trigger update_connected_time_range_closed_reason before update of closed_reason on session_connection + for each row execute procedure update_connected_time_range_on_closed_reason(); + + create function update_session_state_on_termination_reason() returns trigger + as $$ + begin + if new.termination_reason is not null then + perform + from session_connection + where session_id = new.public_id + and upper(connected_time_range) = 'infinity'::timestamptz; + if found then + raise 'session %s has open connections', new.public_id; + end if; + -- check to see if there's a terminated state already, before inserting a + -- new one. + perform + from session_state ss + where ss.session_id = new.public_id and + ss.state = 'terminated'; + if found then + return new; + end if; + insert into session_state (session_id, state) + values (new.public_id, 'terminated'); + end if; + return new; + end; + $$ language plpgsql; + + create trigger update_session_state_on_termination_reason after update of termination_reason on session + for each row execute procedure update_session_state_on_termination_reason(); + + create function wh_insert_session_connection() returns trigger + as $$ + declare + new_row wh_session_connection_accumulating_fact%rowtype; + begin + with + authorized_timestamp (date_dim_key, time_dim_key, ts) as ( + select wh_date_key(create_time), wh_time_key(create_time), create_time + from session_connection + where public_id = new.public_id + and connected_time_range is null + ), + session_dimension (host_dim_key, user_dim_key, credential_group_dim_key) as ( + select host_key, user_key, credential_group_key + from wh_session_accumulating_fact + where session_id = new.session_id + ) + insert into wh_session_connection_accumulating_fact ( + connection_id, + session_id, + host_key, + user_key, + credential_group_key, + connection_authorized_date_key, + connection_authorized_time_key, + connection_authorized_time, + client_tcp_address, + client_tcp_port_number, + endpoint_tcp_address, + endpoint_tcp_port_number, + bytes_up, + bytes_down + ) + select new.public_id, + new.session_id, + session_dimension.host_dim_key, + session_dimension.user_dim_key, + session_dimension.credential_group_dim_key, + authorized_timestamp.date_dim_key, + authorized_timestamp.time_dim_key, + authorized_timestamp.ts, + new.client_tcp_address, + new.client_tcp_port, + new.endpoint_tcp_address, + new.endpoint_tcp_port, + new.bytes_up, + new.bytes_down + from authorized_timestamp, + session_dimension + returning * into strict new_row; + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session_connection after insert on session_connection + for each row execute function wh_insert_session_connection(); + + create function wh_insert_session_connection_state() returns trigger + as $$ + declare + state text; + date_col text; + time_col text; + ts_col text; + q text; + connection_row wh_session_connection_accumulating_fact%rowtype; + begin + if new.connected_time_range is null then + -- Indicates authorized connection. The update statement in this + -- trigger will fail for the authorized state because the row for the + -- session connection has not yet been inserted into the + -- wh_session_connection_accumulating_fact table. + return null; + end if; + + if upper(new.connected_time_range) = 'infinity'::timestamptz then + update wh_session_connection_accumulating_fact + set (connection_connected_date_key, + connection_connected_time_key, + connection_connected_time) = (select wh_date_key(new.update_time), + wh_time_key(new.update_time), + new.update_time::timestamptz) + where connection_id = new.public_id; + else + update wh_session_connection_accumulating_fact + set (connection_closed_date_key, + connection_closed_time_key, + connection_closed_time) = (select wh_date_key(new.update_time), + wh_time_key(new.update_time), + new.update_time::timestamptz) + where connection_id = new.public_id; + end if; + + return null; + end; + $$ language plpgsql; + + create trigger wh_insert_session_connection_state after update of connected_time_range on session_connection + for each row execute function wh_insert_session_connection_state(); + + create view session_connection_with_status_view as + select public_id, + session_id, + client_tcp_address, + client_tcp_port, + endpoint_tcp_address, + endpoint_tcp_port, + bytes_up, + bytes_down, + closed_reason, + version, + create_time, + update_time, + user_client_ip, + worker_id, + case + when connected_time_range is null then 'authorized' + when upper(connected_time_range) > now() then 'connected' + else 'closed' + end as status + from session_connection; + + create index connected_time_range_idx on session_connection (connected_time_range); + + create index connected_time_range_upper_idx on session_connection (upper(connected_time_range)); + +commit; \ No newline at end of file diff --git a/internal/db/sqltest/tests/session/session_connection_state_transition.sql b/internal/db/sqltest/tests/session/session_connection_state_transition.sql new file mode 100644 index 0000000000..61eb6bbb7a --- /dev/null +++ b/internal/db/sqltest/tests/session/session_connection_state_transition.sql @@ -0,0 +1,35 @@ +-- Copyright (c) HashiCorp, Inc. +-- SPDX-License-Identifier: BUSL-1.1 + +begin; + select plan(6); + + -- Ensure session connection table is populated + select is(count(*), 2::bigint) from session_connection; + + -- Check that both session connections are in the authorized state (null connected_time_range) + select is(count(*), 2::bigint) from session_connection where connected_time_range is null; + + -- Connect one of the session connections + update session_connection + set connected_time_range=tstzrange(now(),'infinity') + where public_id = 's1c1___clare'; + select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) > now(); + + -- Close the other session connection + update session_connection + set closed_reason = 'unknown' + where public_id = 's2c1___clare'; + select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) <= now(); + + -- Attempt to connect the closed session connection, expect an error + select throws_ok($$ update session_connection + set connected_time_range = tstzrange(now(), 'infinity') + where public_id = 's2c1___clare'$$); + + -- Still only 1 connected session + select is(count(*), 1::bigint) from session_connection where upper(connected_time_range) > now(); + + select * from finish(); +rollback; + diff --git a/internal/db/sqltest/tests/wh/session_connection/update.sql b/internal/db/sqltest/tests/wh/session_connection/update.sql index 0b23f94ec6..ae40f12d1e 100644 --- a/internal/db/sqltest/tests/wh/session_connection/update.sql +++ b/internal/db/sqltest/tests/wh/session_connection/update.sql @@ -12,7 +12,8 @@ begin; update session_connection set bytes_up = 10, bytes_down = 5, - closed_reason = 'closed by end-user' + closed_reason = 'closed by end-user', + connected_time_range = tstzrange(now()::wh_timestamp, now()::wh_timestamp) where public_id = 's1c1___clare'; select is(count(*), 2::bigint) from wh_session_connection_accumulating_fact; diff --git a/internal/server/repository_worker_test.go b/internal/server/repository_worker_test.go index 89c69671fe..3ef5177524 100644 --- a/internal/server/repository_worker_test.go +++ b/internal/server/repository_worker_test.go @@ -206,10 +206,10 @@ func TestLookupWorker(t *testing.T) { sess := session.TestSession(t, conn, wrapper, composedOf, session.WithDbOpts(db.WithSkipVetForWrite(true)), session.WithExpirationTime(exp)) sess, _, err = sessRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo")) require.NoError(t, err) - c, _, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId()) + c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId()) require.NoError(t, err) require.NotNil(t, c) - c, _, err = connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId()) + c, err = connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), w.GetPublicId()) require.NoError(t, err) require.NotNil(t, c) } @@ -220,7 +220,7 @@ func TestLookupWorker(t *testing.T) { session.WithDbOpts(db.WithSkipVetForWrite(true))) sess2, _, err = sessRepo.ActivateSession(ctx, sess2.GetPublicId(), sess2.Version, []byte("foo")) require.NoError(t, err) - c, _, err := connRepo.AuthorizeConnection(ctx, sess2.GetPublicId(), w.GetPublicId()) + c, err := connRepo.AuthorizeConnection(ctx, sess2.GetPublicId(), w.GetPublicId()) require.NoError(t, err) require.NotNil(t, c) } diff --git a/internal/session/connection.go b/internal/session/connection.go index 9d9d798e38..b4697ad7f9 100644 --- a/internal/session/connection.go +++ b/internal/session/connection.go @@ -13,7 +13,7 @@ import ( ) const ( - defaultConnectionTableName = "session_connection" + defaultConnectionTableName = "session_connection_with_status_view" // "session_connection" ) // Connection contains information about session's connection to a target @@ -44,6 +44,8 @@ type Connection struct { UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"` // Version of the connection Version uint32 `json:"version,omitempty" gorm:"default:null"` + // Status is a field derived from connected_time_range + Status string `json:"status,omitempty" gorm:"default:null"` tableName string `gorm:"-"` } @@ -94,6 +96,7 @@ func (c *Connection) Clone() any { BytesDown: c.BytesDown, ClosedReason: c.ClosedReason, Version: c.Version, + Status: c.Status, } if c.CreateTime != nil { clone.CreateTime = ×tamp.Timestamp{ diff --git a/internal/session/connection_state.go b/internal/session/connection_state.go index d2d3445a88..eda5bcd014 100644 --- a/internal/session/connection_state.go +++ b/internal/session/connection_state.go @@ -4,20 +4,9 @@ package session import ( - "context" - - "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/db/timestamp" - "github.com/hashicorp/boundary/internal/errors" - "google.golang.org/protobuf/types/known/timestamppb" - workerpbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" ) -const ( - defaultConnectionStateTableName = "session_connection_state" -) - // ConnectionStatus of the connection's state type ConnectionStatus string @@ -60,122 +49,14 @@ func ConnectionStatusFromProtoVal(s workerpbs.CONNECTIONSTATUS) ConnectionStatus return StatusUnspecified } -// ConnectionState of the state of the connection -type ConnectionState struct { - // ConnectionId is used to access the state via an API - ConnectionId string `json:"public_id,omitempty" gorm:"primary_key"` - // status of the connection - Status ConnectionStatus `protobuf:"bytes,20,opt,name=status,proto3" json:"status,omitempty" gorm:"column:state"` - // PreviousEndTime from the RDBMS - PreviousEndTime *timestamp.Timestamp `json:"previous_end_time,omitempty" gorm:"default:current_timestamp"` - // StartTime from the RDBMS - StartTime *timestamp.Timestamp `json:"start_time,omitempty" gorm:"default:current_timestamp;primary_key"` - // EndTime from the RDBMS - EndTime *timestamp.Timestamp `json:"end_time,omitempty" gorm:"default:current_timestamp"` - - tableName string `gorm:"-"` -} - -var ( - _ Cloneable = (*ConnectionState)(nil) - _ db.VetForWriter = (*ConnectionState)(nil) -) - -// NewConnectionState creates a new in memory connection state. No options -// are currently supported. -func NewConnectionState(ctx context.Context, connectionId string, state ConnectionStatus, _ ...Option) (*ConnectionState, error) { - const op = "session.NewConnectionState" - s := ConnectionState{ - ConnectionId: connectionId, - Status: state, - } - if err := s.validate(ctx); err != nil { - return nil, errors.Wrap(ctx, err, op) - } - return &s, nil -} - -// allocConnectionState will allocate a connection State -func allocConnectionState() ConnectionState { - return ConnectionState{} -} - -// Clone creates a clone of the State -func (s *ConnectionState) Clone() any { - clone := &ConnectionState{ - ConnectionId: s.ConnectionId, - Status: s.Status, - } - if s.PreviousEndTime != nil { - clone.PreviousEndTime = ×tamp.Timestamp{ - Timestamp: ×tamppb.Timestamp{ - Seconds: s.PreviousEndTime.Timestamp.Seconds, - Nanos: s.PreviousEndTime.Timestamp.Nanos, - }, - } - } - - if s.StartTime != nil { - clone.StartTime = ×tamp.Timestamp{ - Timestamp: ×tamppb.Timestamp{ - Seconds: s.StartTime.Timestamp.Seconds, - Nanos: s.StartTime.Timestamp.Nanos, - }, - } - } - if s.EndTime != nil { - clone.EndTime = ×tamp.Timestamp{ - Timestamp: ×tamppb.Timestamp{ - Seconds: s.EndTime.Timestamp.Seconds, - Nanos: s.EndTime.Timestamp.Nanos, - }, - } - } - return clone -} - -// VetForWrite implements db.VetForWrite() interface and validates the state -// before it's written. -func (s *ConnectionState) VetForWrite(ctx context.Context, _ db.Reader, _ db.OpType, _ ...db.Option) error { - const op = "session.(ConnectionState).VetForWrite" - if err := s.validate(ctx); err != nil { - return errors.Wrap(ctx, err, op) - } - return nil -} - -// TableName returns the tablename to override the default gorm table name -func (s *ConnectionState) TableName() string { - if s.tableName != "" { - return s.tableName - } - return defaultConnectionStateTableName -} - -// SetTableName sets the tablename and satisfies the ReplayableMessage -// interface. If the caller attempts to set the name to "" the name will be -// reset to the default name. -func (s *ConnectionState) SetTableName(n string) { - s.tableName = n -} - -// validate checks the session state -func (s *ConnectionState) validate(ctx context.Context) error { - const op = "session.(ConnectionState).validate" - if s.Status == "" { - return errors.New(ctx, errors.InvalidParameter, op, "missing status") - } - if s.ConnectionId == "" { - return errors.New(ctx, errors.InvalidParameter, op, "missing connection id") - } - if s.StartTime != nil { - return errors.New(ctx, errors.InvalidParameter, op, "start time is not settable") - } - if s.EndTime != nil { - return errors.New(ctx, errors.InvalidParameter, op, "end time is not settable") - } - if s.PreviousEndTime != nil { - return errors.New(ctx, errors.InvalidParameter, op, "previous end time is not settable") +func ConnectionStatusFromString(s string) ConnectionStatus { + switch s { + case "authorized": + return StatusAuthorized + case "connected": + return StatusConnected + case "closed": + return StatusClosed } - return nil + return StatusUnspecified } diff --git a/internal/session/connection_state_test.go b/internal/session/connection_state_test.go deleted file mode 100644 index 5310b35992..0000000000 --- a/internal/session/connection_state_test.go +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package session - -import ( - "context" - "testing" - - "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/errors" - "github.com/hashicorp/boundary/internal/iam" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConnectionState_Create(t *testing.T) { - t.Parallel() - ctx := context.Background() - conn, _ := db.TestSetup(t, "postgres") - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 443, "127.0.0.1", 4443, "127.0.0.1") - - type args struct { - connectionId string - status ConnectionStatus - } - tests := []struct { - name string - args args - want *ConnectionState - wantErr bool - wantIsErr errors.Code - create bool - wantCreateErr bool - }{ - { - name: "valid", - args: args{ - connectionId: connection.PublicId, - status: StatusClosed, - }, - want: &ConnectionState{ - ConnectionId: connection.PublicId, - Status: StatusClosed, - }, - create: true, - }, - { - name: "empty-connectionId", - args: args{ - status: StatusClosed, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - { - name: "empty-status", - args: args{ - connectionId: connection.PublicId, - }, - wantErr: true, - wantIsErr: errors.InvalidParameter, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - got, err := NewConnectionState(ctx, tt.args.connectionId, tt.args.status) - if tt.wantErr { - require.Error(err) - assert.True(errors.Match(errors.T(tt.wantIsErr), err)) - return - } - require.NoError(err) - assert.Equal(tt.want, got) - if tt.create { - err = db.New(conn).Create(ctx, got) - if tt.wantCreateErr { - assert.Error(err) - return - } else { - assert.NoError(err) - } - } - }) - } -} - -func TestConnectionState_Delete(t *testing.T) { - t.Parallel() - ctx := context.Background() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - s := TestDefaultSession(t, conn, wrapper, iamRepo) - c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - - tests := []struct { - name string - state *ConnectionState - deleteConnectionStateId string - wantRowsDeleted int - wantErr bool - wantErrMsg string - }{ - { - name: "valid", - state: TestConnectionState(t, conn, c.PublicId, StatusClosed), - wantErr: false, - wantRowsDeleted: 1, - }, - { - name: "bad-id", - state: TestConnectionState(t, conn, c2.PublicId, StatusClosed), - deleteConnectionStateId: func() string { - id, err := db.NewPublicId(ctx, ConnectionStatePrefix) - require.NoError(t, err) - return id - }(), - wantErr: false, - wantRowsDeleted: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - var initialState ConnectionState - err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", []any{tt.state.ConnectionId, tt.state.Status}) - require.NoError(err) - - deleteState := allocConnectionState() - if tt.deleteConnectionStateId != "" { - deleteState.ConnectionId = tt.deleteConnectionStateId - } else { - deleteState.ConnectionId = tt.state.ConnectionId - } - deleteState.StartTime = initialState.StartTime - deletedRows, err := rw.Delete(ctx, &deleteState) - if tt.wantErr { - require.Error(err) - return - } - require.NoError(err) - if tt.wantRowsDeleted == 0 { - assert.Equal(tt.wantRowsDeleted, deletedRows) - return - } - assert.Equal(tt.wantRowsDeleted, deletedRows) - foundState := allocConnectionState() - err = rw.LookupWhere(ctx, &foundState, "connection_id = ? and start_time = ?", []any{tt.state.ConnectionId, initialState.StartTime}) - require.Error(err) - assert.True(errors.IsNotFoundError(err)) - }) - } -} - -func TestConnectionState_Clone(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - t.Run("valid", func(t *testing.T) { - assert := assert.New(t) - s := TestDefaultSession(t, conn, wrapper, iamRepo) - c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - state := TestConnectionState(t, conn, c.PublicId, StatusConnected) - cp := state.Clone() - assert.Equal(cp.(*ConnectionState), state) - }) - t.Run("not-equal", func(t *testing.T) { - assert := assert.New(t) - s := TestDefaultSession(t, conn, wrapper, iamRepo) - c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - state := TestConnectionState(t, conn, c.PublicId, StatusConnected) - state2 := TestConnectionState(t, conn, c.PublicId, StatusConnected) - - cp := state.Clone() - assert.NotEqual(cp.(*ConnectionState), state2) - }) -} - -func TestConnectionState_SetTableName(t *testing.T) { - t.Parallel() - defaultTableName := defaultConnectionStateTableName - tests := []struct { - name string - setNameTo string - want string - }{ - { - name: "new-name", - setNameTo: "new-name", - want: "new-name", - }, - { - name: "reset to default", - setNameTo: "", - want: defaultTableName, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - def := allocConnectionState() - require.Equal(defaultTableName, def.TableName()) - s := allocConnectionState() - s.SetTableName(tt.setNameTo) - assert.Equal(tt.want, s.TableName()) - }) - } -} diff --git a/internal/session/immutable_fields_test.go b/internal/session/immutable_fields_test.go index fc0baee445..d5efdbe497 100644 --- a/internal/session/immutable_fields_test.go +++ b/internal/session/immutable_fields_test.go @@ -9,7 +9,6 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" - "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -236,96 +235,3 @@ func TestConnection_ImmutableFields(t *testing.T) { }) } } - -func TestConnectionState_ImmutableFields(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - - ts := timestamp.Timestamp{Timestamp: ×tamppb.Timestamp{Seconds: 0, Nanos: 0}} - - _, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) - session := TestDefaultSession(t, conn, wrapper, iamRepo) - connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - state := TestConnectionState(t, conn, connection.PublicId, StatusConnected) - - var new ConnectionState - err := rw.LookupWhere(context.Background(), &new, "connection_id = ? and state = ?", []any{state.ConnectionId, state.Status}) - require.NoError(t, err) - - tests := []struct { - name string - update *ConnectionState - fieldMask []string - wantErrMatch *errors.Template - wantErrContains string - }{ - { - name: "session_id", - update: func() *ConnectionState { - s := new.Clone().(*ConnectionState) - s.ConnectionId = "sc_thisIsNotAValidId" - return s - }(), - fieldMask: []string{"PublicId"}, - }, - { - name: "status", - update: func() *ConnectionState { - s := new.Clone().(*ConnectionState) - s.Status = "closed" - return s - }(), - fieldMask: []string{"Status"}, - wantErrMatch: errors.T(errors.NotSpecificIntegrity), - wantErrContains: "immutable column", - }, - { - name: "start time", - update: func() *ConnectionState { - s := new.Clone().(*ConnectionState) - s.StartTime = &ts - return s - }(), - fieldMask: []string{"StartTime"}, - wantErrMatch: errors.T(errors.InvalidFieldMask), - wantErrContains: "parameter violation", - }, - { - name: "previous_end_time", - update: func() *ConnectionState { - s := new.Clone().(*ConnectionState) - s.PreviousEndTime = &ts - return s - }(), - fieldMask: []string{"PreviousEndTime"}, - wantErrMatch: errors.T(errors.NotSpecificIntegrity), - wantErrContains: "immutable column", - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - orig := new.Clone() - err := rw.LookupWhere(context.Background(), orig, "connection_id = ? and start_time = ?", []any{new.ConnectionId, new.StartTime}) - require.NoError(err) - - rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true)) - require.Error(err) - assert.Equal(0, rowsUpdated) - if tt.wantErrMatch != nil { - assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted error %s and got: %s", tt.wantErrMatch.Code, err.Error()) - } - if tt.wantErrContains != "" { - assert.Contains(err.Error(), tt.wantErrContains) - } - after := new.Clone() - err = rw.LookupWhere(context.Background(), after, "connection_id = ? and start_time = ?", []any{new.ConnectionId, new.StartTime}) - require.NoError(err) - assert.Equal(orig.(*ConnectionState), after) - }) - } -} diff --git a/internal/session/job_session_cleanup_test.go b/internal/session/job_session_cleanup_test.go index 13becf67c2..c98459778f 100644 --- a/internal/session/job_session_cleanup_test.go +++ b/internal/session/job_session_cleanup_test.go @@ -65,10 +65,9 @@ func TestSessionConnectionCleanupJob(t *testing.T) { sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true))) sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo")) require.NoError(err) - c, cs, _, err := AuthorizeConnection(ctx, sessionRepo, connectionRepo, sess.GetPublicId(), serverId) + c, _, err := AuthorizeConnection(ctx, sessionRepo, connectionRepo, sess.GetPublicId(), serverId) require.NoError(err) - require.Len(cs, 1) - require.Equal(StatusAuthorized, cs[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status)) connIds = append(connIds, c.GetPublicId()) if i%2 == 0 { connIdsByWorker[worker2.PublicId] = append(connIdsByWorker[worker2.PublicId], c.GetPublicId()) @@ -81,7 +80,7 @@ func TestSessionConnectionCleanupJob(t *testing.T) { // This is just to ensure we have a spread when we test it out. for i, connId := range connIds { if i%2 == 0 { - _, cs, err := connectionRepo.ConnectConnection(ctx, ConnectWith{ + cc, err := connectionRepo.ConnectConnection(ctx, ConnectWith{ ConnectionId: connId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, @@ -90,18 +89,7 @@ func TestSessionConnectionCleanupJob(t *testing.T) { UserClientIp: "127.0.0.1", }) require.NoError(err) - require.Len(cs, 2) - var foundAuthorized, foundConnected bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusConnected { - foundConnected = true - } - } - require.True(foundAuthorized) - require.True(foundConnected) + require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status)) } } @@ -130,14 +118,11 @@ func TestSessionConnectionCleanupJob(t *testing.T) { require.True(ok) require.Len(connIds, 6) for _, connId := range connIds { - _, states, err := connectionRepo.LookupConnection(ctx, connId) + conn, err := connectionRepo.LookupConnection(ctx, connId) require.NoError(err) var foundClosed bool - for _, state := range states { - if state.Status == StatusClosed { - foundClosed = true - break - } + if ConnectionStatusFromString(conn.Status) == StatusClosed { + foundClosed = true } assert.Equal(closed, foundClosed) } @@ -239,10 +224,9 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) { sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true))) sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo")) require.NoError(err) - c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) require.NoError(err) - require.Len(cs, 1) - require.Equal(StatusAuthorized, cs[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status)) if i%3 == 0 { worker1ConnIds = append(worker1ConnIds, c.GetPublicId()) } else if i%3 == 1 { @@ -263,7 +247,7 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) { return s }() { if i%3 == 0 { - _, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{ + cc, err := connRepo.ConnectConnection(ctx, ConnectWith{ ConnectionId: connId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, @@ -272,18 +256,7 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) { UserClientIp: "127.0.0.1", }) require.NoError(err) - require.Len(cs, 2) - var foundAuthorized, foundConnected bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusConnected { - foundConnected = true - } - } - require.True(foundAuthorized) - require.True(foundConnected) + require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status)) } else if i%3 == 1 { resp, err := connRepo.closeConnections(ctx, []CloseWith{ { @@ -293,19 +266,8 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) { }) require.NoError(err) require.Len(resp, 1) - cs := resp[0].ConnectionStates - require.Len(cs, 2) - var foundAuthorized, foundClosed bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusClosed { - foundClosed = true - } - } - require.True(foundAuthorized) - require.True(foundClosed) + cs := resp[0].ConnectionState + require.Equal(StatusClosed, cs) } } @@ -344,9 +306,9 @@ func TestCloseConnectionsForDeadWorkers(t *testing.T) { expected = StatusAuthorized } - _, states, err := connRepo.LookupConnection(ctx, connId) + conn, err := connRepo.LookupConnection(ctx, connId) require.NoError(err) - require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected) + require.Equal(expected, ConnectionStatusFromString(conn.Status), "expected latest status for %q (index %d) to be %v", connId, i, expected) } } @@ -480,10 +442,9 @@ func TestCloseWorkerlessConnections(t *testing.T) { sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo")) require.NoError(err) - conn, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), workerId) + conn, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), workerId) require.NoError(err) - require.Len(cs, 1) - require.Equal(StatusAuthorized, cs[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(conn.Status)) return conn } @@ -513,21 +474,21 @@ func TestCloseWorkerlessConnections(t *testing.T) { }}) require.NoError(err) - _, st, err := connRepo.LookupConnection(ctx, dActiveConn.GetPublicId()) + con, err := connRepo.LookupConnection(ctx, dActiveConn.GetPublicId()) require.NoError(err) - require.Equal(StatusAuthorized, st[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId()) require.NoError(err) - require.Equal(StatusClosed, st[0].Status) + require.Equal(StatusClosed, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId()) require.NoError(err) - require.Equal(StatusAuthorized, st[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId()) require.NoError(err) - require.Equal(StatusClosed, st[0].Status) + require.Equal(StatusClosed, ConnectionStatusFromString(con.Status)) // Run the job numClosed, err := job.closeWorkerlessConnections(ctx) @@ -535,19 +496,19 @@ func TestCloseWorkerlessConnections(t *testing.T) { assert.Equal(t, 1, numClosed) // This is the only one that the job should have actually closed. - _, st, err = connRepo.LookupConnection(ctx, dActiveConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, dActiveConn.GetPublicId()) require.NoError(err) - require.Equal(StatusClosed, st[0].Status) + require.Equal(StatusClosed, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, dClosedConn.GetPublicId()) require.NoError(err) - require.Equal(StatusClosed, st[0].Status) + require.Equal(StatusClosed, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, activeConn.GetPublicId()) require.NoError(err) - require.Equal(StatusAuthorized, st[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(con.Status)) - _, st, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId()) + con, err = connRepo.LookupConnection(ctx, closedConn.GetPublicId()) require.NoError(err) - require.Equal(StatusClosed, st[0].Status) + require.Equal(StatusClosed, ConnectionStatusFromString(con.Status)) } diff --git a/internal/session/query.go b/internal/session/query.go index b2ceacf77d..da18eef887 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -124,6 +124,14 @@ from session_connection_limit, session_connection_count; ` + // connectConnection sets the connected time range to (now, infinity) to + // indicate the connection is connected. + connectConnection = ` + update session_connection + set connected_time_range=tstzrange(now(),'infinity') + where public_id=@public_id +` + terminateSessionIfPossible = ` -- is terminate_session_id in a canceling state with session_version as ( @@ -197,17 +205,11 @@ from select session_id from - session_connection - where public_id in ( - select - connection_id - from - session_connection_state - where - state != 'closed' and - end_time is null - ) - ) + session_connection + where + upper(connected_time_range) > now() or + connected_time_range is null + ) ` // termSessionUpdate is one stmt that terminates sessions for the following @@ -271,21 +273,12 @@ where ) ) and -- make sure there are no existing connections - us.public_id not in ( - select - session_id - from - session_connection - where public_id in ( - select - connection_id - from - session_connection_state - where - state != 'closed' and - end_time is null - ) -); + us.public_id not in ( + select session_id + from session_connection + where upper(connected_time_range) > now() + or connected_time_range is null + ); ` // closeConnectionsForDeadServersCte finds connections that are: @@ -336,22 +329,20 @@ where and closed_reason is null returning public_id; ` - orphanedConnectionsCte = ` -- Find connections that are not closed so we can reference those IDs with unclosed_connections as ( - select connection_id - from session_connection_state + select public_id + from session_connection where - -- It's the current state - end_time is null - -- Current state isn't closed state - and state in ('authorized', 'connected') + -- It's not closed + upper(connected_time_range) > now() or + connected_time_range is null -- It's not in limbo between when it moved into this state and when -- it started being reported by the worker, which is roughly every -- 2-3 seconds - and start_time < wt_sub_seconds_from_now(@worker_state_delay_seconds) + and update_time < wt_sub_seconds_from_now(@worker_state_delay_seconds) ), connections_to_close as ( select public_id @@ -360,7 +351,7 @@ with -- Related to the worker that just reported to us worker_id = @worker_id -- Only unclosed ones - and public_id in (select connection_id from unclosed_connections) + and public_id in (select public_id from unclosed_connections) -- These are connection IDs that just got reported to us by the given -- worker, so they should not be considered closed. %s diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go index 34975d995b..9c6f830e6b 100644 --- a/internal/session/repository_connection.go +++ b/internal/session/repository_connection.go @@ -134,19 +134,18 @@ func (r *ConnectionRepository) updateBytesUpBytesDown(ctx context.Context, conns // If authorization is success, it creates/stores a new connection in the repo // and returns it, along with its states. If the authorization fails, it // an error with Code InvalidSessionState. -func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, []*ConnectionState, error) { +func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, error) { const op = "session.(ConnectionRepository).AuthorizeConnection" if sessionId == "" { - return nil, nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter)) + return nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter)) } connectionId, err := newConnectionId(ctx) if err != nil { - return nil, nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op) } connection := AllocConnection() connection.PublicId = connectionId - var connectionStates []*ConnectionState _, err = r.writer.DoTx( ctx, db.StdRetryCnt, @@ -166,31 +165,26 @@ func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionI if err := reader.LookupById(ctx, &connection); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for session %s", sessionId))) } - connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc")) - if err != nil { - return errors.Wrap(ctx, err, op) - } return nil }, ) if err != nil { - return nil, nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op) } - return &connection, connectionStates, nil + return &connection, nil } // LookupConnection will look up a connection in the repository and return the connection -// with its states. If the connection is not found, it will return nil, nil, nil. +// with its state. If the connection is not found, it will return nil, nil. // No options are currently supported. -func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, []*ConnectionState, error) { +func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, error) { const op = "session.(ConnectionRepository).LookupConnection" if connectionId == "" { - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing connectionId id") + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing connectionId id") } connection := AllocConnection() connection.PublicId = connectionId - var states []*ConnectionState _, err := r.writer.DoTx( ctx, db.StdRetryCnt, @@ -199,20 +193,16 @@ func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionI if err := read.LookupById(ctx, &connection); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", connectionId))) } - var err error - if states, err = fetchConnectionStates(ctx, read, connectionId, db.WithOrder("start_time desc")); err != nil { - return errors.Wrap(ctx, err, op) - } return nil }, ) if err != nil { if errors.IsNotFoundError(err) { - return nil, nil, nil + return nil, nil } - return nil, nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op) } - return &connection, states, nil + return &connection, nil } // ListConnectionsBySessionId will list connections by session ID. Supports the @@ -231,14 +221,13 @@ func (r *ConnectionRepository) ListConnectionsBySessionId(ctx context.Context, s } // ConnectConnection updates a connection in the repo with a state of "connected". -func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) { +func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, error) { const op = "session.(ConnectionRepository).ConnectConnection" // ConnectWith.validate will check all the fields... if err := c.validate(ctx); err != nil { - return nil, nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op) } var connection Connection - var connectionStates []*ConnectionState _, err := r.writer.DoTx( ctx, db.StdRetryCnt, @@ -266,31 +255,33 @@ func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectW // return err, which will result in a rollback of the update return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated") } - newState, err := NewConnectionState(ctx, connection.PublicId, StatusConnected) + // Set the lower bound of the connected_time_range to indicate the connection is connected + rowsUpdated, err = w.Exec(ctx, connectConnection, []any{ + sql.Named("public_id", c.ConnectionId), + }) if err != nil { return errors.Wrap(ctx, err, op) } - if err := w.Create(ctx, newState); err != nil { - return errors.Wrap(ctx, err, op) + if rowsUpdated != 1 { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to connect connection %s", c.ConnectionId))) } - connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc")) - if err != nil { - return errors.Wrap(ctx, err, op) + if err := reader.LookupById(ctx, &connection); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for connection %s", c.ConnectionId))) } return nil }, ) if err != nil { - return nil, nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op) } - return &connection, connectionStates, nil + return &connection, nil } // closeConnectionResp is just a wrapper for the response from CloseConnections. -// It wraps the connection and its states for each connection closed. +// It wraps the connection and its state for each connection closed. type closeConnectionResp struct { - Connection *Connection - ConnectionStates []*ConnectionState + Connection *Connection + ConnectionState ConnectionStatus } // closeConnections set's a connection's state to "closed" in the repo. It's @@ -318,8 +309,8 @@ func (r *ConnectionRepository) closeConnections(ctx context.Context, closeWith [ updateConnection.BytesUp = cw.BytesUp updateConnection.BytesDown = cw.BytesDown updateConnection.ClosedReason = cw.ClosedReason.String() - // updating the ClosedReason will trigger an insert into the - // session_connection_state with a state of closed. + // updating the ClosedReason will trigger the session_connection to set the + // upper limit of connection_time_range to indicate the connection is closed. rowsUpdated, err := w.Update( ctx, &updateConnection, @@ -332,13 +323,9 @@ func (r *ConnectionRepository) closeConnections(ctx context.Context, closeWith [ if rowsUpdated != 1 { return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("%d would have been updated for connection %s", rowsUpdated, cw.ConnectionId)) } - states, err := fetchConnectionStates(ctx, reader, cw.ConnectionId, db.WithOrder("start_time desc")) - if err != nil { - return errors.Wrap(ctx, err, op) - } resp = append(resp, closeConnectionResp{ - Connection: &updateConnection, - ConnectionStates: states, + Connection: &updateConnection, + ConnectionState: StatusClosed, }) } @@ -441,15 +428,3 @@ func (r *ConnectionRepository) closeOrphanedConnections(ctx context.Context, wor } return orphanedConns, nil } - -func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) { - const op = "session.fetchConnectionStates" - var states []*ConnectionState - if err := r.SearchWhere(ctx, &states, "connection_id = ?", []any{connectionId}, opt...); err != nil { - return nil, errors.Wrap(ctx, err, op) - } - if len(states) == 0 { - return nil, nil - } - return states, nil -} diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go index a0f4100868..3a047a9ce4 100644 --- a/internal/session/repository_connection_test.go +++ b/internal/session/repository_connection_test.go @@ -240,7 +240,7 @@ func TestRepository_ConnectConnection(t *testing.T) { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - c, cs, err := connRepo.ConnectConnection(context.Background(), tt.connectWith) + c, err := connRepo.ConnectConnection(context.Background(), tt.connectWith) if tt.wantErr { require.Error(err) assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) @@ -248,9 +248,8 @@ func TestRepository_ConnectConnection(t *testing.T) { } require.NoError(err) require.NotNil(c) - require.NotNil(cs) - assert.Equal(StatusConnected, cs[0].Status) - gotConn, _, err := connRepo.LookupConnection(context.Background(), c.PublicId) + assert.Equal(StatusConnected, ConnectionStatusFromString(c.Status)) + gotConn, err := connRepo.LookupConnection(context.Background(), c.PublicId) require.NoError(err) assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress) assert.Equal(tt.connectWith.ClientTcpPort, gotConn.ClientTcpPort) @@ -336,7 +335,7 @@ func TestRepository_DeleteConnection(t *testing.T) { } assert.NoError(err) assert.Equal(tt.wantRowsDeleted, deletedRows) - found, _, err := connRepo.LookupConnection(context.Background(), tt.args.connection.PublicId) + found, err := connRepo.LookupConnection(context.Background(), tt.args.connection.PublicId) assert.NoError(err) assert.Nil(found) @@ -378,10 +377,9 @@ func TestRepository_orphanedConnections(t *testing.T) { sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithDbOpts(db.WithSkipVetForWrite(true))) sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, []byte("foo")) require.NoError(err) - c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + c, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) require.NoError(err) - require.Len(cs, 1) - require.Equal(StatusAuthorized, cs[0].Status) + require.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status)) connIds = append(connIds, c.GetPublicId()) if i%2 == 0 { worker2ConnIds = append(worker2ConnIds, c.GetPublicId()) @@ -394,7 +392,7 @@ func TestRepository_orphanedConnections(t *testing.T) { // This is just to ensure we have a spread when we test it out. for i, connId := range connIds { if i%2 == 0 { - _, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{ + cc, err := connRepo.ConnectConnection(ctx, ConnectWith{ ConnectionId: connId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, @@ -403,18 +401,7 @@ func TestRepository_orphanedConnections(t *testing.T) { UserClientIp: "127.0.0.1", }) require.NoError(err) - require.Len(cs, 2) - var foundAuthorized, foundConnected bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusConnected { - foundConnected = true - } - } - require.True(foundAuthorized) - require.True(foundConnected) + require.Equal(StatusConnected, ConnectionStatusFromString(cc.Status)) } } @@ -517,8 +504,8 @@ func TestRepository_CloseConnections(t *testing.T) { assert.Equal(len(tt.closeWith), len(resp)) for _, r := range resp { require.NotNil(r.Connection) - require.NotNil(r.ConnectionStates) - assert.Equal(StatusClosed, r.ConnectionStates[0].Status) + require.NotNil(r.ConnectionState) + assert.Equal(StatusClosed, r.ConnectionState) } }) } @@ -561,7 +548,7 @@ func TestUpdateBytesUpDown(t *testing.T) { // Assert that the bytes up and down values have been persisted. for i := 0; i < len(conns); i++ { - c, _, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId()) + c, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId()) require.NoError(t, err) require.Equal(t, conns[i].BytesUp, c.BytesUp) @@ -604,10 +591,105 @@ func TestUpdateBytesUpDown(t *testing.T) { // BytesUp and BytesDown values should be set to the old ones. for i := 0; i < len(conns); i++ { - c, _, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId()) + c, err := connRepo.LookupConnection(ctx, conns[i].GetPublicId()) require.NoError(t, err) require.Equal(t, conns[i].BytesUp, c.BytesUp) require.Equal(t, conns[i].BytesDown, c.BytesDown) } } + +func TestRepository_StateTransitions(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(ctx, rw, rw, kms) + require.NoError(t, err) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) + require.NoError(t, err) + + s := TestDefaultSession(t, conn, wrapper, iamRepo) + tofu := TestTofu(t) + _, _, err = repo.ActivateSession(context.Background(), s.PublicId, s.Version, tofu) + require.NoError(t, err) + + // First connection will transition authorized -> connected -> closed + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + cw := ConnectWith{ + ConnectionId: c.PublicId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 2222, + UserClientIp: "127.0.0.1", + } + gotConn, err := connRepo.LookupConnection(context.Background(), c.PublicId) + require.NoError(t, err) + require.NotNil(t, gotConn) + require.Equal(t, StatusAuthorized, ConnectionStatusFromString(gotConn.Status)) + + _, err = connRepo.ConnectConnection(context.Background(), cw) + require.NoError(t, err) + + gotConn, err = connRepo.LookupConnection(context.Background(), c.PublicId) + require.NoError(t, err) + require.NotNil(t, gotConn) + require.Equal(t, StatusConnected, ConnectionStatusFromString(gotConn.Status)) + + // Attempt to connect again, expect failure + _, err = connRepo.ConnectConnection(context.Background(), cw) + require.Error(t, err) + require.Contains(t, err.Error(), "Invalid state transition from connected") + + closeWith := CloseWith{ + ConnectionId: c.PublicId, + ClosedReason: ConnectionClosedByUser, + } + resp, err := connRepo.closeConnections(context.Background(), []CloseWith{closeWith}) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusClosed, resp[0].ConnectionState) + + // Second connection will transition from authorized -> closed + c2 := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + + gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId) + require.NoError(t, err) + require.NotNil(t, gotConn) + require.Equal(t, StatusAuthorized, ConnectionStatusFromString(gotConn.Status)) + + closeWith2 := CloseWith{ + ConnectionId: c2.PublicId, + ClosedReason: ConnectionClosedByUser, + } + resp, err = connRepo.closeConnections(context.Background(), []CloseWith{closeWith2}) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, StatusClosed, resp[0].ConnectionState) + gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId) + require.NoError(t, err) + require.NotNil(t, gotConn) + require.Equal(t, StatusClosed, ConnectionStatusFromString(gotConn.Status)) + + // Now try to connect it while closed and ensure it can't transition to connected + cw2 := ConnectWith{ + ConnectionId: c2.PublicId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 2222, + UserClientIp: "127.0.0.1", + } + _, err = connRepo.ConnectConnection(context.Background(), cw2) + require.Error(t, err) + require.Contains(t, err.Error(), "Invalid state transition from closed") + + gotConn, err = connRepo.LookupConnection(context.Background(), c2.PublicId) + require.NoError(t, err) + require.NotNil(t, gotConn) + require.Equal(t, StatusClosed, ConnectionStatusFromString(gotConn.Status)) +} diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index b8faff00f8..001dbd0a66 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -1149,12 +1149,10 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId) require.NoError(err) for _, sc := range conn { - c, cs, err := connRepo.LookupConnection(context.Background(), sc.PublicId) + c, err := connRepo.LookupConnection(context.Background(), sc.PublicId) require.NoError(err) assert.NotEmpty(c.ClosedReason) - for _, s := range cs { - t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime) - } + t.Logf("%s session connection state %s", found.PublicId, c.Status) } } else { t.Logf("not terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit) @@ -1162,11 +1160,9 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId) require.NoError(err) for _, sc := range conn { - cs, err := fetchConnectionStates(context.Background(), rw, sc.PublicId) + c, err := connRepo.LookupConnection(context.Background(), sc.PublicId) require.NoError(err) - for _, s := range cs { - t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime) - } + t.Logf("%s session connection state %s", found.PublicId, c.Status) } } } diff --git a/internal/session/service_authorize_connection.go b/internal/session/service_authorize_connection.go index fb0e229f36..c5e39288d2 100644 --- a/internal/session/service_authorize_connection.go +++ b/internal/session/service_authorize_connection.go @@ -17,17 +17,17 @@ import ( // If any of these criteria is not met, it returns an error with Code InvalidSessionState. func AuthorizeConnection(ctx context.Context, sessionRepoFn *Repository, connectionRepoFn *ConnectionRepository, sessionId, workerId string, opt ...Option, -) (*Connection, []*ConnectionState, *AuthzSummary, error) { +) (*Connection, *AuthzSummary, error) { const op = "session.AuthorizeConnection" - connection, connectionStates, err := connectionRepoFn.AuthorizeConnection(ctx, sessionId, workerId) + connection, err := connectionRepoFn.AuthorizeConnection(ctx, sessionId, workerId) if err != nil { - return nil, nil, nil, errors.Wrap(ctx, err, op) + return nil, nil, errors.Wrap(ctx, err, op) } authzSummary, err := sessionRepoFn.sessionAuthzSummary(ctx, sessionId) if err != nil { - return nil, nil, nil, errors.Wrap(ctx, err, op) + return nil, nil, errors.Wrap(ctx, err, op) } - return connection, connectionStates, authzSummary, nil + return connection, authzSummary, nil } diff --git a/internal/session/service_authorize_connection_test.go b/internal/session/service_authorize_connection_test.go index 8bb44a36bd..712d46e1a3 100644 --- a/internal/session/service_authorize_connection_test.go +++ b/internal/session/service_authorize_connection_test.go @@ -138,7 +138,7 @@ func TestService_AuthorizeConnection(t *testing.T) { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - c, cs, authzInfo, err := AuthorizeConnection(context.Background(), repo, connRepo, tt.session.PublicId, testServer) + c, authzInfo, err := AuthorizeConnection(context.Background(), repo, connRepo, tt.session.PublicId, testServer) if tt.wantErr { require.Error(err) // TODO (jimlambrt 9/2020): add in tests for errorsIs once we @@ -150,8 +150,8 @@ func TestService_AuthorizeConnection(t *testing.T) { } require.NoError(err) require.NotNil(c) - require.NotNil(cs) - assert.Equal(StatusAuthorized, cs[0].Status) + require.NotNil(c.Status) + assert.Equal(StatusAuthorized, ConnectionStatusFromString(c.Status)) assert.True(authzInfo.ExpirationTime.GetTimestamp().AsTime().Sub(tt.wantAuthzInfo.ExpirationTime.GetTimestamp().AsTime()) < 10*time.Millisecond) tt.wantAuthzInfo.ExpirationTime = authzInfo.ExpirationTime diff --git a/internal/session/service_close_connections_test.go b/internal/session/service_close_connections_test.go index 79bed17936..e1b1f152f6 100644 --- a/internal/session/service_close_connections_test.go +++ b/internal/session/service_close_connections_test.go @@ -131,8 +131,8 @@ func TestServiceCloseConnections(t *testing.T) { for _, r := range resp { require.NotNil(r.Connection) - require.NotNil(r.ConnectionStates) - assert.Equal(StatusClosed, r.ConnectionStates[0].Status) + require.NotNil(r.ConnectionState) + assert.Equal(StatusClosed, r.ConnectionState) } // Ensure session is in the state we want- terminated if all conns closed, else active diff --git a/internal/session/service_worker_status_report_test.go b/internal/session/service_worker_status_report_test.go index 1f32e5be48..038a1ad2a0 100644 --- a/internal/session/service_worker_status_report_test.go +++ b/internal/session/service_worker_status_report_test.go @@ -94,7 +94,7 @@ func TestWorkerStatusReport(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - _, _, err = connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + _, err = connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) @@ -126,7 +126,7 @@ func TestWorkerStatusReport(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) return testCase{ worker: worker, @@ -160,7 +160,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu := session.TestTofu(t) sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu) require.NoError(t, err) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) require.NoError(t, err) @@ -224,7 +224,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu := session.TestTofu(t) sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu) require.NoError(t, err) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) require.NoError(t, err) @@ -242,7 +242,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu2 := session.TestTofu(t) sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2) require.NoError(t, err) - connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) + connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version) require.NoError(t, err) @@ -295,7 +295,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu := session.TestTofu(t) sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu) require.NoError(t, err) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{ @@ -311,7 +311,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu2 := session.TestTofu(t) sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2) require.NoError(t, err) - connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) + connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) require.NoError(t, err) require.NotEqual(t, connection.PublicId, connection2.PublicId) @@ -348,7 +348,7 @@ func TestWorkerStatusReport(t *testing.T) { tofu := session.TestTofu(t) sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, tofu) require.NoError(t, err) - connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) + connection, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) require.NoError(t, err) @@ -366,9 +366,9 @@ func TestWorkerStatusReport(t *testing.T) { tofu2 := session.TestTofu(t) sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2) require.NoError(t, err) - connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) + connection2, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) require.NoError(t, err) - connection3, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) + connection3, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PublicId) require.NoError(t, err) _, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version) require.NoError(t, err) @@ -417,12 +417,10 @@ func TestWorkerStatusReport(t *testing.T) { require.NoError(err) assert.ElementsMatch(tc.want, got) for _, dc := range tc.orphanedConnections { - gotConn, states, err := connRepo.LookupConnection(ctx, dc) + gotConn, err := connRepo.LookupConnection(ctx, dc) require.NoError(err) assert.Equal(session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason)) - assert.Equal(2, len(states)) - assert.Nil(states[0].EndTime) - assert.Equal(session.StatusClosed, states[0].Status) + assert.Equal(session.StatusClosed, session.ConnectionStatusFromString(gotConn.Status)) } }) } diff --git a/internal/session/testing.go b/internal/session/testing.go index f9ab69385e..9bd1d60c17 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -41,26 +41,9 @@ func TestConnection(t testing.TB, conn *db.DB, sessionId, clientTcpAddr string, err = rw.Create(ctx, c) require.NoError(err) - connectedState, err := NewConnectionState(ctx, c.PublicId, StatusConnected) - require.NoError(err) - err = rw.Create(ctx, connectedState) - require.NoError(err) return c } -// TestConnectionState creates a test connection state for the connectionId in the repository. -func TestConnectionState(t testing.TB, conn *db.DB, connectionId string, state ConnectionStatus) *ConnectionState { - t.Helper() - ctx := context.Background() - require := require.New(t) - rw := db.New(conn) - s, err := NewConnectionState(ctx, connectionId, state) - require.NoError(err) - err = rw.Create(context.Background(), s) - require.NoError(err) - return s -} - // TestState creates a test state for the sessionId in the repository. func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State { t.Helper() diff --git a/internal/session/testing_test.go b/internal/session/testing_test.go index 6d3dbdc876..6f6a1e5bea 100644 --- a/internal/session/testing_test.go +++ b/internal/session/testing_test.go @@ -50,29 +50,6 @@ func Test_TestConnection(t *testing.T) { require.NotNil(c) } -func Test_TestConnectionState(t *testing.T) { - assert, require := assert.New(t), require.New(t) - conn, _ := db.TestSetup(t, "postgres") - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - s := TestDefaultSession(t, conn, wrapper, iamRepo) - require.NotNil(s) - assert.NotEmpty(s.PublicId) - - c := TestConnection(t, conn, s.PublicId, "0.0.0.0", 22, "0.0.0.0", 2222, "127.0.0.1") - require.NotNil(c) - assert.NotEmpty(c.PublicId) - - cs := TestConnectionState(t, conn, c.PublicId, StatusClosed) - require.NotNil(cs) - - rw := db.New(conn) - var initialState ConnectionState - err := rw.LookupWhere(context.Background(), &initialState, "connection_id = ? and state = ?", []any{cs.ConnectionId, cs.Status}) - require.NoError(err) - assert.NotEmpty(initialState.StartTime) -} - func Test_TestWorker(t *testing.T) { require := require.New(t) conn, _ := db.TestSetup(t, "postgres") diff --git a/internal/tests/helper/testing_helper.go b/internal/tests/helper/testing_helper.go index 3ff17cd8f0..0392eb3f62 100644 --- a/internal/tests/helper/testing_helper.go +++ b/internal/tests/helper/testing_helper.go @@ -155,11 +155,9 @@ func (s *TestSession) ExpectConnectionStateOnController( } for i, conn := range conns { - _, states, err := connectionRepo.LookupConnection(ctx, conn.PublicId, nil) + c, err := connectionRepo.LookupConnection(ctx, conn.PublicId, nil) require.NoError(err) - // Look at the first state in the returned list, which will - // be the most recent state. - actualStates[i] = states[0].Status + actualStates[i] = session.ConnectionStatusFromString(c.Status) } if reflect.DeepEqual(expectStates, actualStates) {