From 19aecfefae471524032681ea181a2b2809d0d83b Mon Sep 17 00:00:00 2001 From: Jim Date: Thu, 1 Oct 2020 10:13:57 -0400 Subject: [PATCH] terminate "completed" sessions (#477) --- internal/db/migrations/postgres.gen.go | 38 +-- .../db/migrations/postgres/50_session.up.sql | 38 +-- internal/servers/controller/controller.go | 1 + internal/servers/controller/tickers.go | 42 +++- internal/session/query.go | 74 ++++++ internal/session/repository_session.go | 35 ++- internal/session/repository_session_test.go | 217 +++++++++++++++++- internal/session/term_reason.go | 4 + 8 files changed, 416 insertions(+), 33 deletions(-) diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go index 1c8e35e375..37f49d2094 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/migrations/postgres.gen.go @@ -3426,7 +3426,9 @@ begin; 'closed by end-user', 'terminated', 'network error', - 'system error' + 'system error', + 'connection limit', + 'canceled' ) ) ); @@ -3438,7 +3440,9 @@ begin; ('closed by end-user'), ('terminated'), ('network error'), - ('system error'); + ('system error'), + ('connection limit'), + ('canceled'); create table session ( public_id wt_public_id primary key, @@ -3591,18 +3595,26 @@ begin; as $$ begin if new.termination_reason is not null then - -- check to see if there are any open connections. - perform from - session_connection sc, - session_connection_state scs - where - sc.public_id = scs.connection_id and - scs.state != 'closed' and - sc.session_id = new.public_id; - if found then - raise 'session %s has existing open connections', new.public_id; + perform from + session + where + public_id = new.public_id and + public_id not in ( + select session_id + from session_connection + where + public_id in ( + select connection_id + from session_connection_state + where + state != 'closed' and + end_time is null + ) + ); + if not found then + raise 'session %s has open connections', new.public_id; end if; - + -- check to see if there's a terminated state already, before inserting a -- new one. perform from diff --git a/internal/db/migrations/postgres/50_session.up.sql b/internal/db/migrations/postgres/50_session.up.sql index e253fff6e3..0e92f7c942 100644 --- a/internal/db/migrations/postgres/50_session.up.sql +++ b/internal/db/migrations/postgres/50_session.up.sql @@ -71,7 +71,9 @@ begin; 'closed by end-user', 'terminated', 'network error', - 'system error' + 'system error', + 'connection limit', + 'canceled' ) ) ); @@ -83,7 +85,9 @@ begin; ('closed by end-user'), ('terminated'), ('network error'), - ('system error'); + ('system error'), + ('connection limit'), + ('canceled'); create table session ( public_id wt_public_id primary key, @@ -236,18 +240,26 @@ begin; as $$ begin if new.termination_reason is not null then - -- check to see if there are any open connections. - perform from - session_connection sc, - session_connection_state scs - where - sc.public_id = scs.connection_id and - scs.state != 'closed' and - sc.session_id = new.public_id; - if found then - raise 'session %s has existing open connections', new.public_id; + perform from + session + where + public_id = new.public_id and + public_id not in ( + select session_id + from session_connection + where + public_id in ( + select connection_id + from session_connection_state + where + state != 'closed' and + end_time is null + ) + ); + if not found then + raise 'session %s has open connections', new.public_id; end if; - + -- check to see if there's a terminated state already, before inserting a -- new one. perform from diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index ce8c03b071..e44beb0e12 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -147,6 +147,7 @@ func (c *Controller) Start() error { c.startStatusTicking(c.baseContext) c.startRecoveryNonceCleanupTicking(c.baseContext) + c.startTerminateCompletedSessionsTicking(c.baseContext) c.started.Store(true) return nil diff --git a/internal/servers/controller/tickers.go b/internal/servers/controller/tickers.go index b1440b4987..5168fc5d79 100644 --- a/internal/servers/controller/tickers.go +++ b/internal/servers/controller/tickers.go @@ -2,6 +2,7 @@ package controller import ( "context" + "math/rand" "time" "github.com/hashicorp/boundary/internal/servers" @@ -10,7 +11,8 @@ import ( // In the future we could make this configurable const ( - statusInterval = 10 * time.Second + statusInterval = 10 * time.Second + terminationInterval = 1 * time.Minute ) // This is exported so it can be tweaked in tests @@ -76,3 +78,41 @@ func (c *Controller) startRecoveryNonceCleanupTicking(cancelCtx context.Context) } }() } + +func (c *Controller) startTerminateCompletedSessionsTicking(cancelCtx context.Context) { + go func() { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + // desynchronize calls from the controllers, to ease the load on the DB. + getRandomInterval := func() time.Duration { + // 0 to 0.5 adjustment to the base + f := r.Float64() / 2 + // Half a chance to be faster, not slower + if r.Float32() > 0.5 { + f = -1 * f + } + return terminationInterval + time.Duration(f*float64(time.Minute)) + } + timer := time.NewTimer(0) + for { + select { + case <-cancelCtx.Done(): + c.logger.Info("terminating completed sessions ticking shutting down") + return + + case <-timer.C: + repo, err := c.SessionRepoFn() + if err != nil { + c.logger.Error("error fetching repository for terminating completed sessions", "error", err) + } else { + terminationCount, err := repo.TerminateCompletedSessions(cancelCtx) + if err != nil { + c.logger.Error("error performing termination of completed sessions", "error", err) + } else if terminationCount > 0 { + c.logger.Info("terminating completed sessions successful", "sessions_terminated", terminationCount) + } + } + timer.Reset(getRandomInterval()) + } + } + }() +} diff --git a/internal/session/query.go b/internal/session/query.go index 967a758ec6..23400c1944 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -139,5 +139,79 @@ where s.public_id = ss.public_id %s %s +` + + // termSessionUpdate is one stmt that terminates sessions for the following + // reasons: + // * sessions that are expired and all their connections are closed. + // * sessions that are canceling and all their connections are closed + // * sessions that have exhausted their connection limit and all their connections are closed. + termSessionsUpdate = ` +with canceling_session(session_id) as +( + select + session_id + from + session_state ss + where + ss.state = 'canceling' and + ss.end_time is null +) +update session us + set termination_reason = + case + -- timed out sessions + when now() > us.expiration_time then 'timed out' + -- canceling sessions + when us.public_id in( + select + session_id + from + canceling_session cs + where + us.public_id = cs.session_id + ) then 'canceled' + -- default: session connection limit reached. + else 'connection limit' + end +where + termination_reason is null and + -- session expired or connection limit reached + ( + -- expired sessions... + now() > us.expiration_time or + -- connection limit reached... + ( + select count (*) + from session_connection sc + where + sc.session_id = us.public_id + ) >= connection_limit or + -- canceled sessions + us.public_id in ( + select + session_id + from + canceling_session cs + where + us.public_id = cs.session_id + ) + ) and + -- make sure there are no existing connections + us.public_id not in ( + select + session_id + from + session_connection + where public_id in ( + select + connection_id + from + session_connection_state + where + state != 'closed' and + end_time is null + ) +) ` ) diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index f877ac0690..49af3b7618 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -284,11 +284,9 @@ func (r *Repository) CancelSession(ctx context.Context, sessionId string, sessio return s, nil } -// TerminateSession sets a session's state to "terminated" in the repo. It's -// called by the worker when the session has been terminated or by a controller -// when all of a session's workers have stopped sending heartbeat status for a -// period of time. Sessions cannot be terminated which still have connections -// that are not closed. +// TerminateSession sets a session's termination reason and it's state to +// "terminated" Sessions cannot be terminated which still have connections that +// are not closed. func (r *Repository) TerminateSession(ctx context.Context, sessionId string, sessionVersion uint32, reason TerminationReason) (*Session, error) { if sessionId == "" { return nil, fmt.Errorf("terminate session: missing session id: %w", db.ErrInvalidParameter) @@ -333,6 +331,33 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses return &updatedSession, nil } +// TerminateCompletedSessions will terminate sessions in the repo based on: +// * sessions that have exhausted their connection limit and all their connections are closed. +// * sessions that are expired and all their connections are closed. +// * sessions that are canceling and all their connections are closed +// This function should called on a periodic basis a Controllers via it's +// "ticker" pattern. +func (r *Repository) TerminateCompletedSessions(ctx context.Context) (int, error) { + var rowsAffected int + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + var err error + rowsAffected, err = w.Exec(ctx, termSessionsUpdate, nil) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return db.NoRowsAffected, fmt.Errorf("terminate completed sessions: %w", err) + } + return rowsAffected, nil +} + // AuthorizeConnection will check to see if a connection is allowed. Currently, // that authorization checks: // * the hasn't expired based on the session.Expiration diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 54fe60738f..19b925f339 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -159,7 +159,6 @@ func TestRepository_ListSession(t *testing.T) { } assert.Equal(10, len(testSessions)) withIds := []string{testSessions[0].PublicId, testSessions[1].PublicId} - conn.LogMode(true) got, err := repo.ListSessions(context.Background(), WithSessionIds(withIds...), WithOrder("create_time asc")) require.NoError(err) assert.Equal(2, len(got)) @@ -715,6 +714,222 @@ func TestRepository_TerminateSession(t *testing.T) { } } +func TestRepository_TerminateCompletedSessions(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + + setupFn := func(limit int32, expireIn time.Duration, leaveOpen bool) *Session { + require.NotEqualf(t, int32(0), limit, "setupFn: limit cannot be zero") + exp, err := ptypes.TimestampProto(time.Now().Add(expireIn)) + require.NoError(t, err) + composedOf := TestSessionParams(t, conn, wrapper, iamRepo) + composedOf.ConnectionLimit = limit + composedOf.ExpirationTime = ×tamp.Timestamp{Timestamp: exp} + s := TestSession(t, conn, wrapper, composedOf) + + srv := TestWorker(t, conn, wrapper) + tofu := TestTofu(t) + s, _, err = repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 222) + if !leaveOpen { + cw := CloseWith{ + ConnectionId: c.PublicId, + BytesUp: 1, + BytesDown: 1, + ClosedReason: ConnectionClosedByUser, + } + _, err = repo.CloseConnections(context.Background(), []CloseWith{cw}) + require.NoError(t, err) + } + return s + } + + type testArgs struct { + sessions []*Session + wantTermed map[string]TerminationReason + } + tests := []struct { + name string + setup func() testArgs + wantErr bool + }{ + { + name: "sessions-with-closed-connections", + setup: func() testArgs { + cnt := 1 + wantTermed := map[string]TerminationReason{} + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + // make one with closed connections + s := setupFn(1, time.Hour+1, false) + wantTermed[s.PublicId] = ConnectionLimit + sessions = append(sessions, s) + + // make one with connection left open + s2 := setupFn(1, time.Hour+1, true) + sessions = append(sessions, s2) + } + return testArgs{ + sessions: sessions, + wantTermed: wantTermed, + } + }, + }, + { + name: "sessions-with-open-and-closed-connections", + setup: func() testArgs { + cnt := 5 + wantTermed := map[string]TerminationReason{} + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + // make one with closed connections + s := setupFn(2, time.Hour+1, false) + _ = TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 222) + sessions = append(sessions, s) + wantTermed[s.PublicId] = ConnectionLimit + } + return testArgs{ + sessions: sessions, + wantTermed: nil, + } + }, + }, + { + name: "sessions-with-no-connections", + setup: func() testArgs { + cnt := 5 + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + sessions = append(sessions, s) + } + return testArgs{ + sessions: sessions, + wantTermed: nil, + } + }, + }, + { + name: "sessions-with-open-connections", + setup: func() testArgs { + cnt := 5 + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + s := setupFn(1, time.Hour+1, true) + sessions = append(sessions, s) + } + return testArgs{ + sessions: sessions, + wantTermed: nil, + } + }, + }, + { + name: "expired-sessions", + setup: func() testArgs { + cnt := 5 + wantTermed := map[string]TerminationReason{} + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + // make one with closed connections + s := setupFn(1, time.Millisecond+1, false) + // make one with connection left open + s2 := setupFn(1, time.Millisecond+1, true) + sessions = append(sessions, s, s2) + wantTermed[s.PublicId] = TimedOut + } + return testArgs{ + sessions: sessions, + wantTermed: wantTermed, + } + }, + }, + { + name: "canceled-sessions-with-closed-connections", + setup: func() testArgs { + cnt := 1 + wantTermed := map[string]TerminationReason{} + sessions := make([]*Session, 0, 5) + for i := 0; i < cnt; i++ { + // make one with limit of 3 and closed connections + s := setupFn(3, time.Hour+1, false) + wantTermed[s.PublicId] = SessionCanceled + sessions = append(sessions, s) + + // make one with connection left open + s2 := setupFn(1, time.Hour+1, true) + sessions = append(sessions, s2) + + // now cancel the sessions + for _, ses := range sessions { + _, err := repo.CancelSession(context.Background(), ses.PublicId, ses.Version) + require.NoError(t, err) + } + } + return testArgs{ + sessions: sessions, + wantTermed: wantTermed, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + require.NoError(conn.Where("1=1").Delete(AllocSession()).Error) + args := tt.setup() + + got, err := repo.TerminateCompletedSessions(context.Background()) + if tt.wantErr { + require.Error(err) + return + } + assert.NoError(err) + assert.Equal(len(args.wantTermed), got) + + for _, ses := range args.sessions { + found, _, err := repo.LookupSession(context.Background(), ses.PublicId) + require.NoError(err) + _, shouldBeTerminated := args.wantTermed[found.PublicId] + if shouldBeTerminated { + assert.Equal(args.wantTermed[found.PublicId].String(), found.TerminationReason) + t.Logf("terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit) + conn, err := repo.ListConnections(context.Background(), found.PublicId) + require.NoError(err) + for _, sc := range conn { + c, cs, err := repo.LookupConnection(context.Background(), sc.PublicId) + require.NoError(err) + assert.NotEmpty(c.ClosedReason) + for _, s := range cs { + t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime) + } + } + } else { + t.Logf("not terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit) + assert.Equal("", found.TerminationReason) + conn, err := repo.ListConnections(context.Background(), found.PublicId) + require.NoError(err) + for _, sc := range conn { + cs, err := fetchConnectionStates(context.Background(), rw, sc.PublicId) + require.NoError(err) + for _, s := range cs { + t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime) + } + } + } + } + + }) + } +} + func TestRepository_CloseConnections(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") diff --git a/internal/session/term_reason.go b/internal/session/term_reason.go index 5b86aeb5a0..b8e21809c0 100644 --- a/internal/session/term_reason.go +++ b/internal/session/term_reason.go @@ -16,6 +16,8 @@ const ( Terminated TerminationReason = "terminated" NetworkError TerminationReason = "network error" SystemError TerminationReason = "system error" + ConnectionLimit TerminationReason = "connection limit" + SessionCanceled TerminationReason = "canceled" ) // String representation of the termination reason @@ -36,6 +38,8 @@ func convertToReason(s string) (TerminationReason, error) { return NetworkError, nil case SystemError.String(): return SystemError, nil + case ConnectionLimit.String(): + return ConnectionLimit, nil default: return "", fmt.Errorf("termination reason: %s is not a valid reason: %w", s, db.ErrInvalidParameter) }