diff --git a/CHANGELOG.md b/CHANGELOG.md index 846e7f187c..203102c6aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ Canonical reference for changes, improvements, and bugfixes for Boundary. * Fix for retrieving sessions that could result in incomplete results when there is a large number (10k+) of sessions. [PR](https://github.com/hashicorp/boundary/pull/2049) +* session: update session state trigger to prevent transitions to invalid states ([Issue](https://github.com/hashicorp/boundary/issues/2040), + [PR](https://github.com/hashicorp/boundary/pull/2046)) ## 0.7.6 (2022/03/15) diff --git a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql index 0eba0883c1..1e16244ff0 100644 --- a/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql +++ b/internal/db/schema/migrations/oss/postgres/0/50_session.up.sql @@ -426,7 +426,7 @@ begin; update on session_state for each row execute procedure immutable_columns('session_id', 'state', 'start_time', 'previous_end_time'); - +-- Replaced in 28/02_prior_session_trigger.up.sql create or replace function insert_session_state() returns trigger diff --git a/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql new file mode 100644 index 0000000000..00a62f8aa1 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/28/02_prior_session_trigger.up.sql @@ -0,0 +1,114 @@ +begin; + +-- Drop prior session state trigger; to be replaced with logic added to insert_session_state() +drop trigger update_session_state on session_state; +drop function update_prior_session_state(); + +-- Remove invalid session transitions lingering in the DB + +-- Create a temp table to classify session states by number of transitions +create temp table state_counts as +select count(*), session_id from session_state group by session_id; + +-- Remove bad states for session ids with 4 states. Remove states added after 'terminated' to remove +-- invalid transitions that occurred after the session terminated +-- Ex: PATC, PTAC, PCTA, PTCA, pruned to valid transitions PAT, PT, PCT, PT +do $$ +declare + target record; +begin + for target in select session_id from state_counts where count = 4 loop + -- If the last state is not terminated for this session... + if (select state from session_state where session_id in + (select session_id from state_counts where count = 4) + and session_id=target.session_id order by start_time desc limit 1) + != 'terminated' then + -- Prune invalid states after the session terminated + delete from session_state where session_id = target.session_id and previous_end_time >( + select previous_end_time from session_state where session_id=target.session_id and state='terminated'); + -- Then remove terminated end time + update session_state set end_time=NULL where session_id=target.session_id and state='terminated'; + end if; + end loop; +end; $$; + +-- Remove bad states for session ids with 3 states, similar to the above +-- Difference from above is the need to check if terminated exists in the set of states +-- Ex: PTA, PTC -> pruned to PT, PT +-- Additional check for state PCA, pruned to PC +-- Valid transitions like PAT, PCT, and in progress sessions like PAC will be ignored +do $$ +declare + target record; +begin + for target in select session_id from state_counts where count = 3 loop + -- If the last state is not terminated for this session... + if (select state from session_state where session_id in + (select session_id from state_counts where count = 3) + and session_id=target.session_id order by start_time desc limit 1) + != 'terminated' then + -- See if terminated appears; if so, prune back to it + if exists(select * from session_state where session_id=target.session_id and state='terminated')then + --Then we find the terminated record and timestamp and delete those that came before + delete from session_state where session_id = target.session_id and previous_end_time >( + select previous_end_time from session_state where session_id=target.session_id and state='terminated'); + -- Then remove terminated end time + update session_state set end_time=NULL where session_id=target.session_id and state='terminated'; + end if; + end if; + -- Check for PCA case; if last state is not cancelled but canceling appears in the states + if (select state from session_state where session_id in + (select session_id from state_counts where count = 3) + and session_id=target.session_id order by start_time desc limit 1) + != 'canceling' then + -- See if canceling appears; if so, prune back to it + if exists(select * from session_state where session_id=target.session_id and state='canceling')then + --Then we find the canceling record and timestamp and delete those that came before + delete from session_state where session_id = target.session_id and previous_end_time >( + select previous_end_time from session_state where session_id=target.session_id and state='canceling'); + -- Then remove canceling end time + update session_state set end_time=NULL where session_id=target.session_id and state='canceling'; + end if; + end if; + end loop; +end; $$; + +-- Replaces trigger from 0/50_session.up.sql +-- Update insert session state transition trigger +drop trigger insert_session_state on session_state; +drop function insert_session_state(); + +create function + insert_session_state() + returns trigger +as $$ +declare + old_col_state text; +begin + update session_state + set end_time = now() + where (session_id = new.session_id + and end_time is null) returning state into old_col_state; + new.prior_state= old_col_state; + + if not found then + new.previous_end_time = null; + new.start_time = now(); + new.end_time = null; + new.prior_state='pending'; + return new; + end if; + + new.previous_end_time = now(); + new.start_time = now(); + new.end_time = null; + + return new; + +end; +$$ language plpgsql; + +create trigger insert_session_state before insert on session_state + for each row execute procedure insert_session_state(); + +commit; \ No newline at end of file diff --git a/internal/db/schema/migrations/oss/postgres_28_02_test.go b/internal/db/schema/migrations/oss/postgres_28_02_test.go new file mode 100644 index 0000000000..d10967cd76 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres_28_02_test.go @@ -0,0 +1,151 @@ +package oss_test + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/internal/authtoken" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/common" + "github.com/hashicorp/boundary/internal/db/schema" + "github.com/hashicorp/boundary/internal/host/static" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/session" + "github.com/hashicorp/boundary/internal/target" + "github.com/hashicorp/boundary/internal/target/tcp" + "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/testing/dbtest" + "github.com/stretchr/testify/require" +) + +func TestMigrations_SessionStateTrigger(t *testing.T) { + t.Parallel() + require := require.New(t) + + const ( + priorMigration = 27002 + currentMigration = 28002 + ) + dialect := dbtest.Postgres + ctx := context.Background() + + c, u, _, err := dbtest.StartUsingTemplate(dialect, dbtest.WithTemplate(dbtest.Template1)) + require.NoError(err) + t.Cleanup(func() { + require.NoError(c()) + }) + d, err := common.SqlOpen(dialect, u) + require.NoError(err) + + // migration to the prior migration (before the one we want to test) + m, err := schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions( + schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": priorMigration}), + )) + require.NoError(err) + + require.NoError(m.ApplyMigrations(ctx)) + state, err := m.CurrentState(ctx) + require.NoError(err) + want := &schema.State{ + Initialized: true, + Editions: []schema.EditionState{ + { + Name: "oss", + BinarySchemaVersion: priorMigration, + DatabaseSchemaVersion: priorMigration, + DatabaseSchemaState: schema.Equal, + }, + }, + } + require.Equal(want, state) + + // Seed the database with test data + dbType, err := db.StringToDbType(dialect) + require.NoError(err) + + conn, err := db.Open(dbType, u) + require.NoError(err) + + rw := db.New(conn) + wrapper := db.TestWrapper(t) + + org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + require.NotNil(prj) + + hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0] + hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0] + h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0] + static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h}) + tar := tcp.TestTarget(ctx, t, conn, prj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()})) + kmsCache := kms.TestKms(t, conn, wrapper) + + serverId := "worker" + tofu := session.TestTofu(t) + session.TestWorker(t, conn, wrapper, session.WithServerId(serverId)) + at := authtoken.TestAuthToken(t, conn, kmsCache, org.GetPublicId()) + uId := at.GetIamUserId() + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 1, + }) + + sessionRepo, err := session.NewRepository(rw, rw, kmsCache) + require.NoError(err) + + // Make and transition a valid session through pending, active, canceling, and terminated + // Unfortunately, could not recreate an invalid session P-A-T-C using the session repo + _, _, err = sessionRepo.ActivateSession(ctx, sess.PublicId, sess.Version, serverId, resource.Worker.String(), tofu) + require.NoError(err) + connection := session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, + "127.0.0.2", 23, "127.0.0.1") + session.TestConnectionState(t, conn, connection.PublicId, session.StatusConnected) + session.TestConnectionState(t, conn, connection.PublicId, session.StatusClosed) + _, err = sessionRepo.CancelSession(ctx, sess.PublicId, sess.Version+1) + require.NoError(err) + sessionRepo.TerminateCompletedSessions(ctx) + + repoSessions, err := sessionRepo.ListSessions(ctx) + require.NoError(err) + var numStates int + for _, s := range repoSessions { + numStates = len(s.States) + } + require.Equal(4, numStates) + + // now we're ready for the migration we want to test. + m, err = schema.NewManager(ctx, schema.Dialect(dialect), d, schema.WithEditions( + schema.TestCreatePartialEditions(schema.Dialect(dialect), schema.PartialEditions{"oss": currentMigration}), + )) + require.NoError(err) + + require.NoError(m.ApplyMigrations(ctx)) + state, err = m.CurrentState(ctx) + require.NoError(err) + want = &schema.State{ + Initialized: true, + Editions: []schema.EditionState{ + { + Name: "oss", + BinarySchemaVersion: currentMigration, + DatabaseSchemaVersion: currentMigration, + DatabaseSchemaState: schema.Equal, + }, + }, + } + require.Equal(want, state) + + // Check that we haven't removed a state + repoSessions, err = sessionRepo.ListSessions(ctx) + require.NoError(err) + for _, s := range repoSessions { + numStates = len(s.States) + } + require.Equal(4, numStates) +} diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 8f7d1b7384..cc9a770c65 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -564,6 +564,103 @@ func TestRepository_updateState(t *testing.T) { } } +func TestRepository_transitionState(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + tofu := TestTofu(t) + srv := TestWorker(t, conn, wrapper) + + tests := []struct { + name string + session *Session + states []Status + wantErr []bool + wantIsError errors.Code + }{ + { + name: "full valid state transition", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusActive, StatusCanceling, StatusTerminated, + }, + wantErr: []bool{false, false, false, false}, + }, + { + name: "partial valid state transition- 1", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusActive, StatusTerminated, + }, + wantErr: []bool{false, false, false, false}, + }, + { + name: "partial valid state transition- 2", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusCanceling, StatusTerminated, + }, + wantErr: []bool{false, false, false}, + }, + { + name: "invalid state transition - 1", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusCanceling, StatusTerminated, StatusActive, + }, + wantErr: []bool{false, false, false, true}, + }, + { + name: "invalid state transition - 2", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusCanceling, StatusActive, + }, + wantErr: []bool{false, false, true}, + }, + { + name: "invalid state transition - 3", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + states: []Status{ + StatusPending, StatusTerminated, StatusActive, + }, + wantErr: []bool{false, false, true}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + id := tt.session.PublicId + version := tt.session.Version + for i, status := range tt.states { + var s *Session + var ss []*State + var err error + if status == StatusActive { + s, ss, err = repo.ActivateSession(context.Background(), id, version, srv.PrivateId, srv.Type, tofu) + } else { + s, ss, err = repo.updateState(context.Background(), id, version, status) + } + if tt.wantErr[i] { + require.Error(err) + assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) + return + } + require.NoError(err) + require.NotNil(s) + require.NotNil(ss) + assert.Equal(status, ss[0].Status) + version = s.Version + } + }) + } +} + func TestRepository_TerminateCompletedSessions(t *testing.T) { t.Parallel() ctx := context.Background()