terminate "completed" sessions (#477)

pull/520/head
Jim 6 years ago committed by GitHub
parent d851ab07d8
commit 19aecfefae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3426,7 +3426,9 @@ begin;
'closed by end-user',
'terminated',
'network error',
'system error'
'system error',
'connection limit',
'canceled'
)
)
);
@ -3438,7 +3440,9 @@ begin;
('closed by end-user'),
('terminated'),
('network error'),
('system error');
('system error'),
('connection limit'),
('canceled');
create table session (
public_id wt_public_id primary key,
@ -3591,18 +3595,26 @@ begin;
as $$
begin
if new.termination_reason is not null then
-- check to see if there are any open connections.
perform from
session_connection sc,
session_connection_state scs
where
sc.public_id = scs.connection_id and
scs.state != 'closed' and
sc.session_id = new.public_id;
if found then
raise 'session %s has existing open connections', new.public_id;
perform from
session
where
public_id = new.public_id and
public_id not in (
select session_id
from session_connection
where
public_id in (
select connection_id
from session_connection_state
where
state != 'closed' and
end_time is null
)
);
if not found then
raise 'session %s has open connections', new.public_id;
end if;
-- check to see if there's a terminated state already, before inserting a
-- new one.
perform from

@ -71,7 +71,9 @@ begin;
'closed by end-user',
'terminated',
'network error',
'system error'
'system error',
'connection limit',
'canceled'
)
)
);
@ -83,7 +85,9 @@ begin;
('closed by end-user'),
('terminated'),
('network error'),
('system error');
('system error'),
('connection limit'),
('canceled');
create table session (
public_id wt_public_id primary key,
@ -236,18 +240,26 @@ begin;
as $$
begin
if new.termination_reason is not null then
-- check to see if there are any open connections.
perform from
session_connection sc,
session_connection_state scs
where
sc.public_id = scs.connection_id and
scs.state != 'closed' and
sc.session_id = new.public_id;
if found then
raise 'session %s has existing open connections', new.public_id;
perform from
session
where
public_id = new.public_id and
public_id not in (
select session_id
from session_connection
where
public_id in (
select connection_id
from session_connection_state
where
state != 'closed' and
end_time is null
)
);
if not found then
raise 'session %s has open connections', new.public_id;
end if;
-- check to see if there's a terminated state already, before inserting a
-- new one.
perform from

@ -147,6 +147,7 @@ func (c *Controller) Start() error {
c.startStatusTicking(c.baseContext)
c.startRecoveryNonceCleanupTicking(c.baseContext)
c.startTerminateCompletedSessionsTicking(c.baseContext)
c.started.Store(true)
return nil

@ -2,6 +2,7 @@ package controller
import (
"context"
"math/rand"
"time"
"github.com/hashicorp/boundary/internal/servers"
@ -10,7 +11,8 @@ import (
// In the future we could make this configurable
const (
statusInterval = 10 * time.Second
statusInterval = 10 * time.Second
terminationInterval = 1 * time.Minute
)
// This is exported so it can be tweaked in tests
@ -76,3 +78,41 @@ func (c *Controller) startRecoveryNonceCleanupTicking(cancelCtx context.Context)
}
}()
}
func (c *Controller) startTerminateCompletedSessionsTicking(cancelCtx context.Context) {
go func() {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
// desynchronize calls from the controllers, to ease the load on the DB.
getRandomInterval := func() time.Duration {
// 0 to 0.5 adjustment to the base
f := r.Float64() / 2
// Half a chance to be faster, not slower
if r.Float32() > 0.5 {
f = -1 * f
}
return terminationInterval + time.Duration(f*float64(time.Minute))
}
timer := time.NewTimer(0)
for {
select {
case <-cancelCtx.Done():
c.logger.Info("terminating completed sessions ticking shutting down")
return
case <-timer.C:
repo, err := c.SessionRepoFn()
if err != nil {
c.logger.Error("error fetching repository for terminating completed sessions", "error", err)
} else {
terminationCount, err := repo.TerminateCompletedSessions(cancelCtx)
if err != nil {
c.logger.Error("error performing termination of completed sessions", "error", err)
} else if terminationCount > 0 {
c.logger.Info("terminating completed sessions successful", "sessions_terminated", terminationCount)
}
}
timer.Reset(getRandomInterval())
}
}
}()
}

@ -139,5 +139,79 @@ where
s.public_id = ss.public_id
%s
%s
`
// termSessionUpdate is one stmt that terminates sessions for the following
// reasons:
// * sessions that are expired and all their connections are closed.
// * sessions that are canceling and all their connections are closed
// * sessions that have exhausted their connection limit and all their connections are closed.
termSessionsUpdate = `
with canceling_session(session_id) as
(
select
session_id
from
session_state ss
where
ss.state = 'canceling' and
ss.end_time is null
)
update session us
set termination_reason =
case
-- timed out sessions
when now() > us.expiration_time then 'timed out'
-- canceling sessions
when us.public_id in(
select
session_id
from
canceling_session cs
where
us.public_id = cs.session_id
) then 'canceled'
-- default: session connection limit reached.
else 'connection limit'
end
where
termination_reason is null and
-- session expired or connection limit reached
(
-- expired sessions...
now() > us.expiration_time or
-- connection limit reached...
(
select count (*)
from session_connection sc
where
sc.session_id = us.public_id
) >= connection_limit or
-- canceled sessions
us.public_id in (
select
session_id
from
canceling_session cs
where
us.public_id = cs.session_id
)
) and
-- make sure there are no existing connections
us.public_id not in (
select
session_id
from
session_connection
where public_id in (
select
connection_id
from
session_connection_state
where
state != 'closed' and
end_time is null
)
)
`
)

@ -284,11 +284,9 @@ func (r *Repository) CancelSession(ctx context.Context, sessionId string, sessio
return s, nil
}
// TerminateSession sets a session's state to "terminated" in the repo. It's
// called by the worker when the session has been terminated or by a controller
// when all of a session's workers have stopped sending heartbeat status for a
// period of time. Sessions cannot be terminated which still have connections
// that are not closed.
// TerminateSession sets a session's termination reason and it's state to
// "terminated" Sessions cannot be terminated which still have connections that
// are not closed.
func (r *Repository) TerminateSession(ctx context.Context, sessionId string, sessionVersion uint32, reason TerminationReason) (*Session, error) {
if sessionId == "" {
return nil, fmt.Errorf("terminate session: missing session id: %w", db.ErrInvalidParameter)
@ -333,6 +331,33 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses
return &updatedSession, nil
}
// TerminateCompletedSessions will terminate sessions in the repo based on:
// * sessions that have exhausted their connection limit and all their connections are closed.
// * sessions that are expired and all their connections are closed.
// * sessions that are canceling and all their connections are closed
// This function should called on a periodic basis a Controllers via it's
// "ticker" pattern.
func (r *Repository) TerminateCompletedSessions(ctx context.Context) (int, error) {
var rowsAffected int
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
var err error
rowsAffected, err = w.Exec(ctx, termSessionsUpdate, nil)
if err != nil {
return err
}
return nil
},
)
if err != nil {
return db.NoRowsAffected, fmt.Errorf("terminate completed sessions: %w", err)
}
return rowsAffected, nil
}
// AuthorizeConnection will check to see if a connection is allowed. Currently,
// that authorization checks:
// * the hasn't expired based on the session.Expiration

@ -159,7 +159,6 @@ func TestRepository_ListSession(t *testing.T) {
}
assert.Equal(10, len(testSessions))
withIds := []string{testSessions[0].PublicId, testSessions[1].PublicId}
conn.LogMode(true)
got, err := repo.ListSessions(context.Background(), WithSessionIds(withIds...), WithOrder("create_time asc"))
require.NoError(err)
assert.Equal(2, len(got))
@ -715,6 +714,222 @@ func TestRepository_TerminateSession(t *testing.T) {
}
}
func TestRepository_TerminateCompletedSessions(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
setupFn := func(limit int32, expireIn time.Duration, leaveOpen bool) *Session {
require.NotEqualf(t, int32(0), limit, "setupFn: limit cannot be zero")
exp, err := ptypes.TimestampProto(time.Now().Add(expireIn))
require.NoError(t, err)
composedOf := TestSessionParams(t, conn, wrapper, iamRepo)
composedOf.ConnectionLimit = limit
composedOf.ExpirationTime = &timestamp.Timestamp{Timestamp: exp}
s := TestSession(t, conn, wrapper, composedOf)
srv := TestWorker(t, conn, wrapper)
tofu := TestTofu(t)
s, _, err = repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu)
require.NoError(t, err)
c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 222)
if !leaveOpen {
cw := CloseWith{
ConnectionId: c.PublicId,
BytesUp: 1,
BytesDown: 1,
ClosedReason: ConnectionClosedByUser,
}
_, err = repo.CloseConnections(context.Background(), []CloseWith{cw})
require.NoError(t, err)
}
return s
}
type testArgs struct {
sessions []*Session
wantTermed map[string]TerminationReason
}
tests := []struct {
name string
setup func() testArgs
wantErr bool
}{
{
name: "sessions-with-closed-connections",
setup: func() testArgs {
cnt := 1
wantTermed := map[string]TerminationReason{}
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
// make one with closed connections
s := setupFn(1, time.Hour+1, false)
wantTermed[s.PublicId] = ConnectionLimit
sessions = append(sessions, s)
// make one with connection left open
s2 := setupFn(1, time.Hour+1, true)
sessions = append(sessions, s2)
}
return testArgs{
sessions: sessions,
wantTermed: wantTermed,
}
},
},
{
name: "sessions-with-open-and-closed-connections",
setup: func() testArgs {
cnt := 5
wantTermed := map[string]TerminationReason{}
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
// make one with closed connections
s := setupFn(2, time.Hour+1, false)
_ = TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 222)
sessions = append(sessions, s)
wantTermed[s.PublicId] = ConnectionLimit
}
return testArgs{
sessions: sessions,
wantTermed: nil,
}
},
},
{
name: "sessions-with-no-connections",
setup: func() testArgs {
cnt := 5
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
sessions = append(sessions, s)
}
return testArgs{
sessions: sessions,
wantTermed: nil,
}
},
},
{
name: "sessions-with-open-connections",
setup: func() testArgs {
cnt := 5
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
s := setupFn(1, time.Hour+1, true)
sessions = append(sessions, s)
}
return testArgs{
sessions: sessions,
wantTermed: nil,
}
},
},
{
name: "expired-sessions",
setup: func() testArgs {
cnt := 5
wantTermed := map[string]TerminationReason{}
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
// make one with closed connections
s := setupFn(1, time.Millisecond+1, false)
// make one with connection left open
s2 := setupFn(1, time.Millisecond+1, true)
sessions = append(sessions, s, s2)
wantTermed[s.PublicId] = TimedOut
}
return testArgs{
sessions: sessions,
wantTermed: wantTermed,
}
},
},
{
name: "canceled-sessions-with-closed-connections",
setup: func() testArgs {
cnt := 1
wantTermed := map[string]TerminationReason{}
sessions := make([]*Session, 0, 5)
for i := 0; i < cnt; i++ {
// make one with limit of 3 and closed connections
s := setupFn(3, time.Hour+1, false)
wantTermed[s.PublicId] = SessionCanceled
sessions = append(sessions, s)
// make one with connection left open
s2 := setupFn(1, time.Hour+1, true)
sessions = append(sessions, s2)
// now cancel the sessions
for _, ses := range sessions {
_, err := repo.CancelSession(context.Background(), ses.PublicId, ses.Version)
require.NoError(t, err)
}
}
return testArgs{
sessions: sessions,
wantTermed: wantTermed,
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
require.NoError(conn.Where("1=1").Delete(AllocSession()).Error)
args := tt.setup()
got, err := repo.TerminateCompletedSessions(context.Background())
if tt.wantErr {
require.Error(err)
return
}
assert.NoError(err)
assert.Equal(len(args.wantTermed), got)
for _, ses := range args.sessions {
found, _, err := repo.LookupSession(context.Background(), ses.PublicId)
require.NoError(err)
_, shouldBeTerminated := args.wantTermed[found.PublicId]
if shouldBeTerminated {
assert.Equal(args.wantTermed[found.PublicId].String(), found.TerminationReason)
t.Logf("terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit)
conn, err := repo.ListConnections(context.Background(), found.PublicId)
require.NoError(err)
for _, sc := range conn {
c, cs, err := repo.LookupConnection(context.Background(), sc.PublicId)
require.NoError(err)
assert.NotEmpty(c.ClosedReason)
for _, s := range cs {
t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime)
}
}
} else {
t.Logf("not terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit)
assert.Equal("", found.TerminationReason)
conn, err := repo.ListConnections(context.Background(), found.PublicId)
require.NoError(err)
for _, sc := range conn {
cs, err := fetchConnectionStates(context.Background(), rw, sc.PublicId)
require.NoError(err)
for _, s := range cs {
t.Logf("%s session %s connection state %s at %s", found.PublicId, s.ConnectionId, s.Status, s.EndTime)
}
}
}
}
})
}
}
func TestRepository_CloseConnections(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")

@ -16,6 +16,8 @@ const (
Terminated TerminationReason = "terminated"
NetworkError TerminationReason = "network error"
SystemError TerminationReason = "system error"
ConnectionLimit TerminationReason = "connection limit"
SessionCanceled TerminationReason = "canceled"
)
// String representation of the termination reason
@ -36,6 +38,8 @@ func convertToReason(s string) (TerminationReason, error) {
return NetworkError, nil
case SystemError.String():
return SystemError, nil
case ConnectionLimit.String():
return ConnectionLimit, nil
default:
return "", fmt.Errorf("termination reason: %s is not a valid reason: %w", s, db.ErrInvalidParameter)
}

Loading…
Cancel
Save