mirror of https://github.com/hashicorp/boundary
Merge pull request #2046 from hashicorp/irindos-session-transition-trigger
fix(session): Update session state transition triggerpull/2050/head
commit
ca2599ead4
@ -0,0 +1,114 @@
|
||||
begin;
|
||||
|
||||
-- Drop prior session state trigger; to be replaced with logic added to insert_session_state()
|
||||
drop trigger update_session_state on session_state;
|
||||
drop function update_prior_session_state();
|
||||
|
||||
-- Remove invalid session transitions lingering in the DB
|
||||
|
||||
-- Create a temp table to classify session states by number of transitions
|
||||
create temp table state_counts as
|
||||
select count(*), session_id from session_state group by session_id;
|
||||
|
||||
-- Remove bad states for session ids with 4 states. Remove states added after 'terminated' to remove
|
||||
-- invalid transitions that occurred after the session terminated
|
||||
-- Ex: PATC, PTAC, PCTA, PTCA, pruned to valid transitions PAT, PT, PCT, PT
|
||||
do $$
|
||||
declare
|
||||
target record;
|
||||
begin
|
||||
for target in select session_id from state_counts where count = 4 loop
|
||||
-- If the last state is not terminated for this session...
|
||||
if (select state from session_state where session_id in
|
||||
(select session_id from state_counts where count = 4)
|
||||
and session_id=target.session_id order by start_time desc limit 1)
|
||||
!= 'terminated' then
|
||||
-- Prune invalid states after the session terminated
|
||||
delete from session_state where session_id = target.session_id and previous_end_time >(
|
||||
select previous_end_time from session_state where session_id=target.session_id and state='terminated');
|
||||
-- Then remove terminated end time
|
||||
update session_state set end_time=NULL where session_id=target.session_id and state='terminated';
|
||||
end if;
|
||||
end loop;
|
||||
end; $$;
|
||||
|
||||
-- Remove bad states for session ids with 3 states, similar to the above
|
||||
-- Difference from above is the need to check if terminated exists in the set of states
|
||||
-- Ex: PTA, PTC -> pruned to PT, PT
|
||||
-- Additional check for state PCA, pruned to PC
|
||||
-- Valid transitions like PAT, PCT, and in progress sessions like PAC will be ignored
|
||||
do $$
|
||||
declare
|
||||
target record;
|
||||
begin
|
||||
for target in select session_id from state_counts where count = 3 loop
|
||||
-- If the last state is not terminated for this session...
|
||||
if (select state from session_state where session_id in
|
||||
(select session_id from state_counts where count = 3)
|
||||
and session_id=target.session_id order by start_time desc limit 1)
|
||||
!= 'terminated' then
|
||||
-- See if terminated appears; if so, prune back to it
|
||||
if exists(select * from session_state where session_id=target.session_id and state='terminated')then
|
||||
--Then we find the terminated record and timestamp and delete those that came before
|
||||
delete from session_state where session_id = target.session_id and previous_end_time >(
|
||||
select previous_end_time from session_state where session_id=target.session_id and state='terminated');
|
||||
-- Then remove terminated end time
|
||||
update session_state set end_time=NULL where session_id=target.session_id and state='terminated';
|
||||
end if;
|
||||
end if;
|
||||
-- Check for PCA case; if last state is not cancelled but canceling appears in the states
|
||||
if (select state from session_state where session_id in
|
||||
(select session_id from state_counts where count = 3)
|
||||
and session_id=target.session_id order by start_time desc limit 1)
|
||||
!= 'canceling' then
|
||||
-- See if canceling appears; if so, prune back to it
|
||||
if exists(select * from session_state where session_id=target.session_id and state='canceling')then
|
||||
--Then we find the canceling record and timestamp and delete those that came before
|
||||
delete from session_state where session_id = target.session_id and previous_end_time >(
|
||||
select previous_end_time from session_state where session_id=target.session_id and state='canceling');
|
||||
-- Then remove canceling end time
|
||||
update session_state set end_time=NULL where session_id=target.session_id and state='canceling';
|
||||
end if;
|
||||
end if;
|
||||
end loop;
|
||||
end; $$;
|
||||
|
||||
-- Replaces trigger from 0/50_session.up.sql
|
||||
-- Update insert session state transition trigger
|
||||
drop trigger insert_session_state on session_state;
|
||||
drop function insert_session_state();
|
||||
|
||||
create function
|
||||
insert_session_state()
|
||||
returns trigger
|
||||
as $$
|
||||
declare
|
||||
old_col_state text;
|
||||
begin
|
||||
update session_state
|
||||
set end_time = now()
|
||||
where (session_id = new.session_id
|
||||
and end_time is null) returning state into old_col_state;
|
||||
new.prior_state= old_col_state;
|
||||
|
||||
if not found then
|
||||
new.previous_end_time = null;
|
||||
new.start_time = now();
|
||||
new.end_time = null;
|
||||
new.prior_state='pending';
|
||||
return new;
|
||||
end if;
|
||||
|
||||
new.previous_end_time = now();
|
||||
new.start_time = now();
|
||||
new.end_time = null;
|
||||
|
||||
return new;
|
||||
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger insert_session_state before insert on session_state
|
||||
for each row execute procedure insert_session_state();
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,151 @@
|
||||
package oss_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/authtoken"
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/db/common"
|
||||
"github.com/hashicorp/boundary/internal/db/schema"
|
||||
"github.com/hashicorp/boundary/internal/host/static"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/hashicorp/boundary/internal/session"
|
||||
"github.com/hashicorp/boundary/internal/target"
|
||||
"github.com/hashicorp/boundary/internal/target/tcp"
|
||||
"github.com/hashicorp/boundary/internal/types/resource"
|
||||
"github.com/hashicorp/boundary/testing/dbtest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrations_SessionStateTrigger(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
|
||||
const (
|
||||
priorMigration = 27002
|
||||
currentMigration = 28002
|
||||
)
|
||||
dialect := dbtest.Postgres
|
||||
ctx := context.Background()
|
||||
|
||||
c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1))
|
||||
require.NoError(err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(c())
|
||||
})
|
||||
d, err := common.SqlOpen(dialect, u)
|
||||
require.NoError(err)
|
||||
|
||||
// migration to the prior migration (before the one we want to test)
|
||||
m, err := schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions(
|
||||
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": priorMigration}),
|
||||
))
|
||||
require.NoError(err)
|
||||
|
||||
require.NoError(m.ApplyMigrations(ctx))
|
||||
state, err := m.CurrentState(ctx)
|
||||
require.NoError(err)
|
||||
want := &schema.State{
|
||||
Initialized: true,
|
||||
Editions: []schema.EditionState{
|
||||
{
|
||||
Name: "oss",
|
||||
BinarySchemaVersion: priorMigration,
|
||||
DatabaseSchemaVersion: priorMigration,
|
||||
DatabaseSchemaState: schema.Equal,
|
||||
},
|
||||
},
|
||||
}
|
||||
require.Equal(want, state)
|
||||
|
||||
// Seed the database with test data
|
||||
dbType, err := db.StringToDbType(dialect)
|
||||
require.NoError(err)
|
||||
|
||||
conn, err := db.Open(dbType, u)
|
||||
require.NoError(err)
|
||||
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
|
||||
org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
||||
require.NotNil(prj)
|
||||
|
||||
hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0]
|
||||
hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
|
||||
h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
|
||||
static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
|
||||
tar := tcp.TestTarget(ctx, t, conn, prj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))
|
||||
kmsCache := kms.TestKms(t, conn, wrapper)
|
||||
|
||||
serverId := "worker"
|
||||
tofu := session.TestTofu(t)
|
||||
session.TestWorker(t, conn, wrapper, session.WithServerId(serverId))
|
||||
at := authtoken.TestAuthToken(t, conn, kmsCache, org.GetPublicId())
|
||||
uId := at.GetIamUserId()
|
||||
sess := session.TestSession(t, conn, wrapper, session.ComposedOf{
|
||||
UserId: uId,
|
||||
HostId: h.GetPublicId(),
|
||||
TargetId: tar.GetPublicId(),
|
||||
HostSetId: hs.GetPublicId(),
|
||||
AuthTokenId: at.GetPublicId(),
|
||||
ScopeId: prj.GetPublicId(),
|
||||
Endpoint: "tcp://127.0.0.1:22",
|
||||
ConnectionLimit: 1,
|
||||
})
|
||||
|
||||
sessionRepo, err := session.NewRepository(rw, rw, kmsCache)
|
||||
require.NoError(err)
|
||||
|
||||
// Make and transition a valid session through pending, active, canceling, and terminated
|
||||
// Unfortunately, could not recreate an invalid session P-A-T-C using the session repo
|
||||
_, _, err = sessionRepo.ActivateSession(ctx, sess.PublicId, sess.Version, serverId, resource.Worker.String(), tofu)
|
||||
require.NoError(err)
|
||||
connection := session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22,
|
||||
"127.0.0.2", 23, "127.0.0.1")
|
||||
session.TestConnectionState(t, conn, connection.PublicId, session.StatusConnected)
|
||||
session.TestConnectionState(t, conn, connection.PublicId, session.StatusClosed)
|
||||
_, err = sessionRepo.CancelSession(ctx, sess.PublicId, sess.Version+1)
|
||||
require.NoError(err)
|
||||
sessionRepo.TerminateCompletedSessions(ctx)
|
||||
|
||||
repoSessions, err := sessionRepo.ListSessions(ctx)
|
||||
require.NoError(err)
|
||||
var numStates int
|
||||
for _, s := range repoSessions {
|
||||
numStates = len(s.States)
|
||||
}
|
||||
require.Equal(4, numStates)
|
||||
|
||||
// now we're ready for the migration we want to test.
|
||||
m, err = schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions(
|
||||
schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": currentMigration}),
|
||||
))
|
||||
require.NoError(err)
|
||||
|
||||
require.NoError(m.ApplyMigrations(ctx))
|
||||
state, err = m.CurrentState(ctx)
|
||||
require.NoError(err)
|
||||
want = &schema.State{
|
||||
Initialized: true,
|
||||
Editions: []schema.EditionState{
|
||||
{
|
||||
Name: "oss",
|
||||
BinarySchemaVersion: currentMigration,
|
||||
DatabaseSchemaVersion: currentMigration,
|
||||
DatabaseSchemaState: schema.Equal,
|
||||
},
|
||||
},
|
||||
}
|
||||
require.Equal(want, state)
|
||||
|
||||
// Check that we haven't removed a state
|
||||
repoSessions, err = sessionRepo.ListSessions(ctx)
|
||||
require.NoError(err)
|
||||
for _, s := range repoSessions {
|
||||
numStates = len(s.States)
|
||||
}
|
||||
require.Equal(4, numStates)
|
||||
}
|
||||
Loading…
Reference in new issue