From 57da9f918e18cba56fc4feb734515683abd12a50 Mon Sep 17 00:00:00 2001 From: Jim Date: Tue, 22 Sep 2020 14:36:20 -0400 Subject: [PATCH] cancel a session when one of its FKs is set to null (#406) --- internal/db/migrations/postgres.gen.go | 65 +++++- .../db/migrations/postgres/50_session.up.sql | 65 +++++- internal/session/repository_session_test.go | 190 ++++++++++++++++++ 3 files changed, 318 insertions(+), 2 deletions(-) diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go index a625635afa..c510616a7f 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/migrations/postgres.gen.go @@ -3494,9 +3494,11 @@ begin; update on session for each row execute procedure immutable_columns('public_id', 'certificate', 'expiration_time', 'connection_limit', 'create_time', 'endpoint'); + -- session table has some cascades of FK to null, so we need to be careful + -- which columns trigger an update of the version column create trigger update_version_column - after update on session + after update of version, termination_reason, key_id, tofu_token, server_id, server_type on session for each row execute procedure update_version_column(); create trigger @@ -3606,6 +3608,67 @@ begin; for each row execute procedure update_session_state_on_termination_reason(); + -- cancel_session will insert a cancel state for the session, if there's isn't + -- a canceled state already. It's used by cancel_session_with_null_fk. + create or replace function + cancel_session(in sessionId text) returns void + as $$ + declare + rows_affected numeric; + begin + insert into session_state(session_id, state) + select + sessionId::text, 'canceling' + from + session s + where + s.public_id = sessionId::text and + s.public_id not in ( + select + session_id + from + session_state + where + session_id = sessionId::text and + state = 'canceling' + ) limit 1; + get diagnostics rows_affected = row_count; + if rows_affected > 1 then + raise exception 'cancel session: more than one row affected: %', rows_affected; + end if; + end; + $$ language plpgsql; + + -- cancel_session_with_null_fk is intended to be a before update trigger that + -- sets the session's state to cancel if a FK is set to null. + create or replace function + cancel_session_with_null_fk() + returns trigger + as $$ + begin + case + when new.user_id is null then + perform cancel_session(new.public_id); + when new.host_id is null then + perform cancel_session(new.public_id); + when new.target_id is null then + perform cancel_session(new.public_id); + when new.host_set_id is null then + perform cancel_session(new.public_id); + when new.auth_token_id is null then + perform cancel_session(new.public_id); + when new.scope_id is null then + perform cancel_session(new.public_id); + end case; + return new; + end; + $$ language plpgsql; + + create trigger + cancel_session_with_null_fk + before update of user_id, host_id, target_id, host_set_id, auth_token_id, scope_id on session + for each row execute procedure cancel_session_with_null_fk(); + create table session_state_enm ( name text primary key check ( diff --git a/internal/db/migrations/postgres/50_session.up.sql b/internal/db/migrations/postgres/50_session.up.sql index 2845f2be07..be0f54e630 100644 --- a/internal/db/migrations/postgres/50_session.up.sql +++ b/internal/db/migrations/postgres/50_session.up.sql @@ -158,9 +158,11 @@ begin; update on session for each row execute procedure immutable_columns('public_id', 'certificate', 'expiration_time', 'connection_limit', 'create_time', 'endpoint'); + -- session table has some cascades of FK to null, so we need to be careful + -- which columns trigger an update of the version column create trigger update_version_column - after update on session + after update of version, termination_reason, key_id, tofu_token, server_id, server_type on session for each row execute procedure update_version_column(); create trigger @@ -270,6 +272,67 @@ begin; for each row execute procedure update_session_state_on_termination_reason(); + -- cancel_session will insert a cancel state for the session, if there's isn't + -- a canceled state already. It's used by cancel_session_with_null_fk. + create or replace function + cancel_session(in sessionId text) returns void + as $$ + declare + rows_affected numeric; + begin + insert into session_state(session_id, state) + select + sessionId::text, 'canceling' + from + session s + where + s.public_id = sessionId::text and + s.public_id not in ( + select + session_id + from + session_state + where + session_id = sessionId::text and + state = 'canceling' + ) limit 1; + get diagnostics rows_affected = row_count; + if rows_affected > 1 then + raise exception 'cancel session: more than one row affected: %', rows_affected; + end if; + end; + $$ language plpgsql; + + -- cancel_session_with_null_fk is intended to be a before update trigger that + -- sets the session's state to cancel if a FK is set to null. + create or replace function + cancel_session_with_null_fk() + returns trigger + as $$ + begin + case + when new.user_id is null then + perform cancel_session(new.public_id); + when new.host_id is null then + perform cancel_session(new.public_id); + when new.target_id is null then + perform cancel_session(new.public_id); + when new.host_set_id is null then + perform cancel_session(new.public_id); + when new.auth_token_id is null then + perform cancel_session(new.public_id); + when new.scope_id is null then + perform cancel_session(new.public_id); + end case; + return new; + end; + $$ language plpgsql; + + create trigger + cancel_session_with_null_fk + before update of user_id, host_id, target_id, host_set_id, auth_token_id, scope_id on session + for each row execute procedure cancel_session_with_null_fk(); + create table session_state_enm ( name text primary key check ( diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index fcbf8775c0..54fe60738f 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -7,9 +7,19 @@ import ( "time" "github.com/golang/protobuf/ptypes" + "github.com/hashicorp/boundary/internal/authtoken" + authtokenStore "github.com/hashicorp/boundary/internal/authtoken/store" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/host/static" + staticStore "github.com/hashicorp/boundary/internal/host/static/store" + "github.com/hashicorp/boundary/internal/target" + targetStore "github.com/hashicorp/boundary/internal/target/store" + "github.com/lib/pq" + "github.com/hashicorp/boundary/internal/iam" + iamStore "github.com/hashicorp/boundary/internal/iam/store" + "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" "github.com/stretchr/testify/assert" @@ -897,6 +907,186 @@ func TestRepository_CancelSession(t *testing.T) { }) } } + +func TestRepository_CancelSessionViaFKNull(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + setupFn := func() *Session { + session := TestDefaultSession(t, conn, wrapper, iamRepo) + _ = TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + return session + } + type cancelFk struct { + s *Session + fkType interface{} + } + tests := []struct { + name string + cancelFk cancelFk + }{ + { + name: "UserId", + cancelFk: func() cancelFk { + s := setupFn() + t := &iam.User{ + User: &iamStore.User{ + PublicId: s.UserId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "Host", + cancelFk: func() cancelFk { + s := setupFn() + + t := &static.Host{ + Host: &staticStore.Host{ + PublicId: s.HostId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "Target", + cancelFk: func() cancelFk { + s := setupFn() + + t := &target.TcpTarget{ + TcpTarget: &targetStore.TcpTarget{ + PublicId: s.TargetId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "HostSet", + cancelFk: func() cancelFk { + s := setupFn() + + t := &static.HostSet{ + HostSet: &staticStore.HostSet{ + PublicId: s.HostSetId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "AuthToken", + cancelFk: func() cancelFk { + s := setupFn() + + t := &authtoken.AuthToken{ + AuthToken: &authtokenStore.AuthToken{ + PublicId: s.AuthTokenId, + }, + } + // override the table name so we can delete this thing, since + // it's default table name is a non-writable view. + t.SetTableName("auth_token") + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "Scope", + cancelFk: func() cancelFk { + s := setupFn() + + t := &iam.Scope{ + Scope: &iamStore.Scope{ + PublicId: s.ScopeId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + { + name: "canceled-only-once", + cancelFk: func() cancelFk { + s := setupFn() + var err error + s, err = repo.CancelSession(context.Background(), s.PublicId, s.Version) + require.NoError(t, err) + require.Equal(t, StatusCanceling, s.States[0].Status) + + t := &static.Host{ + Host: &staticStore.Host{ + PublicId: s.HostId, + }, + } + return cancelFk{ + s: s, + fkType: t, + } + }(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + s, _, err := repo.LookupSession(context.Background(), tt.cancelFk.s.PublicId) + require.NoError(err) + require.NotNil(s) + require.NotNil(s.States) + + rowsDeleted, err := rw.Delete(context.Background(), tt.cancelFk.fkType) + if err != nil { + var pqError *pq.Error + if errors.As(err, &pqError) { + t.Log(pqError.Message) + t.Log(pqError.Detail) + t.Log(pqError.Where) + t.Log(pqError.Constraint) + t.Log(pqError.Table) + } + } + require.NoError(err) + require.Equal(1, rowsDeleted) + + s, _, err = repo.LookupSession(context.Background(), tt.cancelFk.s.PublicId) + require.NoError(err) + require.NotNil(s) + require.NotNil(s.States) + assert.Equal(StatusCanceling, s.States[0].Status) + canceledCnt := 0 + for _, ss := range s.States { + if ss.Status == StatusCanceling { + canceledCnt += 1 + } + } + assert.Equal(1, canceledCnt) + }) + } +} + func TestRepository_ActivateSession(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres")