Merge pull request #2046 from hashicorp/irindos-session-transition-trigger

fix(session): Update session state transition trigger
pull/2050/head
Irena Rindos 4 years ago committed by GitHub
commit ca2599ead4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,8 @@ Canonical reference for changes, improvements, and bugfixes for Boundary.
* Fix for retrieving sessions that could result in incomplete results when
there is a large number (10k+) of sessions.
[PR](https://github.com/hashicorp/boundary/pull/2049)
* session: update session state trigger to prevent transitions to invalid states ([Issue](https://github.com/hashicorp/boundary/issues/2040),
[PR](https://github.com/hashicorp/boundary/pull/2046))
## 0.7.6 (2022/03/15)

@ -426,7 +426,7 @@ begin;
update on session_state
for each row execute procedure immutable_columns('session_id', 'state', 'start_time', 'previous_end_time');
-- Replaced in 28/02_prior_session_trigger.up.sql
create or replace function
insert_session_state()
returns trigger

@ -0,0 +1,114 @@
begin;
-- Drop prior session state trigger; to be replaced with logic added to insert_session_state()
drop trigger update_session_state on session_state;
drop function update_prior_session_state();
-- Remove invalid session transitions lingering in the DB
-- Create a temp table to classify session states by number of transitions
create temp table state_counts as
select count(*), session_id from session_state group by session_id;
-- Remove bad states for session ids with 4 states. Remove states added after 'terminated' to remove
-- invalid transitions that occurred after the session terminated
-- Ex: PATC, PTAC, PCTA, PTCA, pruned to valid transitions PAT, PT, PCT, PT
do $$
declare
target record;
begin
for target in select session_id from state_counts where count = 4 loop
-- If the last state is not terminated for this session...
if (select state from session_state where session_id in
(select session_id from state_counts where count = 4)
and session_id=target.session_id order by start_time desc limit 1)
!= 'terminated' then
-- Prune invalid states after the session terminated
delete from session_state where session_id = target.session_id and previous_end_time >(
select previous_end_time from session_state where session_id=target.session_id and state='terminated');
-- Then remove terminated end time
update session_state set end_time=NULL where session_id=target.session_id and state='terminated';
end if;
end loop;
end; $$;
-- Remove bad states for session ids with 3 states, similar to the above
-- Difference from above is the need to check if terminated exists in the set of states
-- Ex: PTA, PTC -> pruned to PT, PT
-- Additional check for state PCA, pruned to PC
-- Valid transitions like PAT, PCT, and in progress sessions like PAC will be ignored
do $$
declare
target record;
begin
for target in select session_id from state_counts where count = 3 loop
-- If the last state is not terminated for this session...
if (select state from session_state where session_id in
(select session_id from state_counts where count = 3)
and session_id=target.session_id order by start_time desc limit 1)
!= 'terminated' then
-- See if terminated appears; if so, prune back to it
if exists(select * from session_state where session_id=target.session_id and state='terminated')then
--Then we find the terminated record and timestamp and delete those that came before
delete from session_state where session_id = target.session_id and previous_end_time >(
select previous_end_time from session_state where session_id=target.session_id and state='terminated');
-- Then remove terminated end time
update session_state set end_time=NULL where session_id=target.session_id and state='terminated';
end if;
end if;
-- Check for PCA case; if last state is not cancelled but canceling appears in the states
if (select state from session_state where session_id in
(select session_id from state_counts where count = 3)
and session_id=target.session_id order by start_time desc limit 1)
!= 'canceling' then
-- See if canceling appears; if so, prune back to it
if exists(select * from session_state where session_id=target.session_id and state='canceling')then
--Then we find the canceling record and timestamp and delete those that came before
delete from session_state where session_id = target.session_id and previous_end_time >(
select previous_end_time from session_state where session_id=target.session_id and state='canceling');
-- Then remove canceling end time
update session_state set end_time=NULL where session_id=target.session_id and state='canceling';
end if;
end if;
end loop;
end; $$;
-- Replaces trigger from 0/50_session.up.sql
-- Update insert session state transition trigger
drop trigger insert_session_state on session_state;
drop function insert_session_state();
create function
insert_session_state()
returns trigger
as $$
declare
old_col_state text;
begin
update session_state
set end_time = now()
where (session_id = new.session_id
and end_time is null) returning state into old_col_state;
new.prior_state= old_col_state;
if not found then
new.previous_end_time = null;
new.start_time = now();
new.end_time = null;
new.prior_state='pending';
return new;
end if;
new.previous_end_time = now();
new.start_time = now();
new.end_time = null;
return new;
end;
$$ language plpgsql;
create trigger insert_session_state before insert on session_state
for each row execute procedure insert_session_state();
commit;

@ -0,0 +1,151 @@
package oss_test
import (
"context"
"testing"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/internal/target/tcp"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/testing/dbtest"
"github.com/stretchr/testify/require"
)
func TestMigrations_SessionStateTrigger(t *testing.T) {
t.Parallel()
require := require.New(t)
const (
priorMigration = 27002
currentMigration = 28002
)
dialect := dbtest.Postgres
ctx := context.Background()
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())
})
d, err := common.SqlOpen(dialect, u)
require.NoError(err)
// migration to the prior migration (before the one we want to test)
m, err := schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions(
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": priorMigration}),
))
require.NoError(err)
require.NoError(m.ApplyMigrations(ctx))
state, err := m.CurrentState(ctx)
require.NoError(err)
want := &schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "oss",
BinarySchemaVersion: priorMigration,
DatabaseSchemaVersion: priorMigration,
DatabaseSchemaState: schema.Equal,
},
},
}
require.Equal(want, state)
// Seed the database with test data
dbType, err := db.StringToDbType(dialect)
require.NoError(err)
conn, err := db.Open(dbType, u)
require.NoError(err)
rw := db.New(conn)
wrapper := db.TestWrapper(t)
org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
require.NotNil(prj)
hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0]
hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
tar := tcp.TestTarget(ctx, t, conn, prj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))
kmsCache := kms.TestKms(t, conn, wrapper)
serverId := "worker"
tofu := session.TestTofu(t)
session.TestWorker(t, conn, wrapper, session.WithServerId(serverId))
at := authtoken.TestAuthToken(t, conn, kmsCache, org.GetPublicId())
uId := at.GetIamUserId()
sess := session.TestSession(t, conn, wrapper, session.ComposedOf{
UserId: uId,
HostId: h.GetPublicId(),
TargetId: tar.GetPublicId(),
HostSetId: hs.GetPublicId(),
AuthTokenId: at.GetPublicId(),
ScopeId: prj.GetPublicId(),
Endpoint: "tcp://127.0.0.1:22",
ConnectionLimit: 1,
})
sessionRepo, err := session.NewRepository(rw, rw, kmsCache)
require.NoError(err)
// Make and transition a valid session through pending, active, canceling, and terminated
// Unfortunately, could not recreate an invalid session P-A-T-C using the session repo
_, _, err = sessionRepo.ActivateSession(ctx, sess.PublicId, sess.Version, serverId, resource.Worker.String(), tofu)
require.NoError(err)
connection := session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22,
"127.0.0.2", 23, "127.0.0.1")
session.TestConnectionState(t, conn, connection.PublicId, session.StatusConnected)
session.TestConnectionState(t, conn, connection.PublicId, session.StatusClosed)
_, err = sessionRepo.CancelSession(ctx, sess.PublicId, sess.Version+1)
require.NoError(err)
sessionRepo.TerminateCompletedSessions(ctx)
repoSessions, err := sessionRepo.ListSessions(ctx)
require.NoError(err)
var numStates int
for _, s := range repoSessions {
numStates = len(s.States)
}
require.Equal(4, numStates)
// now we're ready for the migration we want to test.
m, err = schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions(
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": currentMigration}),
))
require.NoError(err)
require.NoError(m.ApplyMigrations(ctx))
state, err = m.CurrentState(ctx)
require.NoError(err)
want = &schema.State{
Initialized: true,
Editions: []schema.EditionState{
{
Name: "oss",
BinarySchemaVersion: currentMigration,
DatabaseSchemaVersion: currentMigration,
DatabaseSchemaState: schema.Equal,
},
},
}
require.Equal(want, state)
// Check that we haven't removed a state
repoSessions, err = sessionRepo.ListSessions(ctx)
require.NoError(err)
for _, s := range repoSessions {
numStates = len(s.States)
}
require.Equal(4, numStates)
}

@ -564,6 +564,103 @@ func TestRepository_updateState(t *testing.T) {
}
}
func TestRepository_transitionState(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
tofu := TestTofu(t)
srv := TestWorker(t, conn, wrapper)
tests := []struct {
name string
session *Session
states []Status
wantErr []bool
wantIsError errors.Code
}{
{
name: "full valid state transition",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusActive, StatusCanceling, StatusTerminated,
},
wantErr: []bool{false, false, false, false},
},
{
name: "partial valid state transition- 1",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusActive, StatusTerminated,
},
wantErr: []bool{false, false, false, false},
},
{
name: "partial valid state transition- 2",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusCanceling, StatusTerminated,
},
wantErr: []bool{false, false, false},
},
{
name: "invalid state transition - 1",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusCanceling, StatusTerminated, StatusActive,
},
wantErr: []bool{false, false, false, true},
},
{
name: "invalid state transition - 2",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusCanceling, StatusActive,
},
wantErr: []bool{false, false, true},
},
{
name: "invalid state transition - 3",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
states: []Status{
StatusPending, StatusTerminated, StatusActive,
},
wantErr: []bool{false, false, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
id := tt.session.PublicId
version := tt.session.Version
for i, status := range tt.states {
var s *Session
var ss []*State
var err error
if status == StatusActive {
s, ss, err = repo.ActivateSession(context.Background(), id, version, srv.PrivateId, srv.Type, tofu)
} else {
s, ss, err = repo.updateState(context.Background(), id, version, status)
}
if tt.wantErr[i] {
require.Error(err)
assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error())
return
}
require.NoError(err)
require.NotNil(s)
require.NotNil(ss)
assert.Equal(status, ss[0].Status)
version = s.Version
}
})
}
}
func TestRepository_TerminateCompletedSessions(t *testing.T) {
t.Parallel()
ctx := context.Background()

Loading…
Cancel
Save