internal/servers/controller: Worker failure connection cleanup (#1340)

This commit adds support for the following:

* Mark connections for non-reporting workers as closed. This is the
controller counterpart to the worker functionality (see #1330). This is
written as a scheduled job that does the work DB-side in a single atomic
query.

* Reconciles state if such a broken controller-worker connection
resumes and the worker reports a connection as connected that should
be disconnected. In this case, the controller will send an update
request, and the worker will honor it and terminate the connection.

* Further refinement of the grace period setting. We have converged
on the current server "liveness" setting as our default, which is
half of the previous 30s (15 seconds, in other words). Additionally,
the grace period is now configurable on both the controller and
worker side, with the caveat that it currently cannot be set in
config because the setting is untagged in HCL. This is exposed so
that we can run some sophisticated testing scenarios where we skew
the grace period toward either the controller or worker to ensure the
aforementioned reconciliation works (a sketch follows this list).

* Some repository functions have been added to support the new
functionality, along with some test code in the worker to allow
querying session state while testing.

* Finally, we've added timestamp subtraction functions as well,
serving as the opposite of the existing addition functions.
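As a rough illustration of the skewed-grace-period testing mentioned above, the sketch below (hypothetical, not part of this change) uses the new StatusGracePeriodDuration field on TestControllerOpts to hold the controller's grace period artificially high so that the worker side reacts first. The TestSkewedGracePeriodSketch name is invented, tc.Shutdown() is assumed from the existing test helpers, and the worker wiring is left as a comment since the accessors it needs are outside this diff.

package controller_test

import (
	"testing"
	"time"

	"github.com/hashicorp/boundary/internal/servers/controller"
)

// TestSkewedGracePeriodSketch is an illustrative sketch only. It gives the
// controller a very long grace period so its dead-worker cleanup job stays
// quiet, leaving the worker-side grace period (and the reconciliation path)
// as the behavior under observation.
func TestSkewedGracePeriodSketch(t *testing.T) {
	tc := controller.NewTestController(t, &controller.TestControllerOpts{
		// Effectively disable controller-side dead-worker cleanup for the
		// duration of the test.
		StatusGracePeriodDuration: 10 * time.Minute,
	})
	defer tc.Shutdown() // assumed helper; see existing controller tests

	// A TestWorker would be started here with the default (15s) grace
	// period. After interrupting its connection to the controller, the test
	// can use tc.WaitForNextWorkerStatusUpdate and the worker's
	// WaitForNextSuccessfulStatusUpdate helpers (both added in this change)
	// to observe the reconciliation described above.
}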

@ -15,6 +15,7 @@ import (
"strings"
"sync"
"syscall"
"time"
"github.com/armon/go-metrics"
berrors "github.com/hashicorp/boundary/internal/errors"
@ -24,6 +25,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/version"
"github.com/hashicorp/go-hclog"
@ -40,6 +42,23 @@ import (
"google.golang.org/grpc/grpclog"
)
const (
// defaultStatusGracePeriod is the default status grace period, or the period
// of time that we will go without a status report before we start
// disconnecting and marking connections as closed. This is tied to the
// server default liveness setting, a related value. See the servers package
// for more details.
defaultStatusGracePeriod = servers.DefaultLiveness
// statusGracePeriodEnvVar is the environment variable that can be used to
// configure the status grace period. This setting is provided in seconds,
// and can never be lower than the default status grace period defined above.
//
// TODO: This value is temporary, it will be removed once we have a better
// story/direction on attributes and system defaults.
statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD"
)
type Server struct {
*Command
@ -101,6 +120,11 @@ type Server struct {
DevDatabaseCleanupFunc func() error
Database *gorm.DB
// StatusGracePeriodDuration represents the period of time (as a
// duration) that the controller will wait before marking
// connections from a disconnected worker as invalid.
StatusGracePeriodDuration time.Duration
}
func NewServer(cmd *Command) *Server {
@ -675,3 +699,51 @@ func MakeSighupCh() chan struct{} {
}()
return resultCh
}
// SetStatusGracePeriodDuration sets the value for
// StatusGracePeriodDuration.
//
// The grace period is the length of time we allow connections to run
// on a worker in the event of an error sending status updates. The
// period is defined as the length of time since the last successful
// update.
//
// The setting is derived from one of the following, in order:
//
// * Via the supplied value if non-zero.
// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an
// integer value to define the setting.
// * If both of these are unset, the default is used. See the
// defaultStatusGracePeriod constant for the default value.
//
// The minimum setting for this value is the default setting. Values
// below this will be reset to the default.
func (s *Server) SetStatusGracePeriodDuration(value time.Duration) {
var result time.Duration
switch {
case value > 0:
result = value
case os.Getenv(statusGracePeriodEnvVar) != "":
// TODO: See the description of the constant for more details on
// this env var
v := os.Getenv(statusGracePeriodEnvVar)
n, err := strconv.Atoi(v)
if err != nil {
s.Logger.Error(fmt.Sprintf("could not read setting for %s", statusGracePeriodEnvVar),
"err", err,
"value", v,
)
break
}
result = time.Second * time.Duration(n)
}
if result < defaultStatusGracePeriod {
s.Logger.Debug("invalid grace period setting or none provided, using default", "value", result, "default", defaultStatusGracePeriod)
result = defaultStatusGracePeriod
}
s.Logger.Debug("session cleanup in effect, connections will be terminated if status reports cannot be made", "grace_period", result)
s.StatusGracePeriodDuration = result
}
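A minimal sketch of the precedence documented above. The construction below is illustrative only: base.NewServer(&base.Command{}) and the direct Logger assignment are assumptions made so the snippet is self-contained; real callers receive a configured *base.Server through the normal command setup.

package main

import (
	"fmt"
	"os"
	"time"

	"github.com/hashicorp/boundary/internal/cmd/base"
	"github.com/hashicorp/go-hclog"
)

func main() {
	// Assumed construction for illustration only.
	s := base.NewServer(&base.Command{})
	s.Logger = hclog.NewNullLogger()

	// 1. An explicit non-zero value wins over the env var and the default.
	s.SetStatusGracePeriodDuration(30 * time.Second)
	fmt.Println(s.StatusGracePeriodDuration) // 30s

	// 2. With a zero value, BOUNDARY_STATUS_GRACE_PERIOD (in seconds) is
	// consulted.
	os.Setenv("BOUNDARY_STATUS_GRACE_PERIOD", "60")
	s.SetStatusGracePeriodDuration(0)
	fmt.Println(s.StatusGracePeriodDuration) // 1m0s

	// 3. Anything below the default (servers.DefaultLiveness, 15s) is reset
	// to the default.
	os.Unsetenv("BOUNDARY_STATUS_GRACE_PERIOD")
	s.SetStatusGracePeriodDuration(5 * time.Second)
	fmt.Println(s.StatusGracePeriodDuration) // 15s
}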

@ -425,6 +425,10 @@ func (c *Command) Run(args []string) int {
return base.CommandCliError
}
// Initialize status grace period (0 denotes using env or default
// here)
c.SetStatusGracePeriodDuration(0)
base.StartMemProfiler(c.Logger)
if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil {

@ -156,6 +156,10 @@ func (c *Command) Run(args []string) int {
return base.CommandUserError
}
// Initialize status grace period (0 denotes using env or default
// here)
c.SetStatusGracePeriodDuration(0)
base.StartMemProfiler(c.Logger)
if !c.skipMetrics {

@ -119,6 +119,13 @@ type Controller struct {
// denoted by time.Duration
AuthTokenTimeToStale interface{} `hcl:"auth_token_time_to_stale"`
AuthTokenTimeToStaleDuration time.Duration
// StatusGracePeriodDuration represents the period of time (as a duration) that the
// controller will wait before marking connections from a disconnected worker
// as invalid.
//
// TODO: This field is currently internal.
StatusGracePeriodDuration time.Duration `hcl:"-"`
}
type Worker struct {
@ -132,6 +139,13 @@ type Worker struct {
// key=value syntax. This is trued up in the Parse function below.
TagsRaw interface{} `hcl:"tags"`
Tags map[string][]string `hcl:"-"`
// StatusGracePeriodDuration represents the period of time (as a duration) that the
// worker will wait before disconnecting connections if it cannot make a
// status report to a controller.
//
// TODO: This field is currently internal.
StatusGracePeriodDuration time.Duration `hcl:"-"`
}
type Database struct {

@ -0,0 +1,23 @@
begin;
create function wt_sub_seconds(sec integer, ts timestamp with time zone)
returns timestamp with time zone
as $$
select ts - sec * '1 second'::interval;
$$ language sql
stable
returns null on null input;
comment on function wt_sub_seconds is
'wt_sub_seconds returns ts - sec.';
create function wt_sub_seconds_from_now(sec integer)
returns timestamp with time zone
as $$
select wt_sub_seconds(sec, current_timestamp);
$$ language sql
stable
returns null on null input;
comment on function wt_sub_seconds_from_now is
'wt_sub_seconds_from_now returns current_timestamp - sec.';
commit;

@ -0,0 +1,61 @@
package migration
import (
"context"
"database/sql"
"testing"
"time"
"github.com/hashicorp/boundary/internal/db/schema"
"github.com/hashicorp/boundary/internal/docker"
"github.com/stretchr/testify/require"
)
const targetMigration = 12001
func TestWtSubSeconds(t *testing.T) {
t.Parallel()
require := require.New(t)
ctx := context.Background()
d := testSetupDb(ctx, t)
// Test by subtracting a day from the test date
sourceTime, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05+07:00")
require.NoError(err)
expectedTime := sourceTime.Add(time.Second * -86400)
var actualTime time.Time
row := d.QueryRowContext(ctx, "select wt_sub_seconds($1, $2)", 86400, sourceTime)
require.NoError(row.Scan(&actualTime))
require.True(expectedTime.Equal(actualTime))
}
func testSetupDb(ctx context.Context, t *testing.T) *sql.DB {
t.Helper()
require := require.New(t)
dialect := "postgres"
c, u, _, err := docker.StartDbInDocker(dialect)
require.NoError(err)
t.Cleanup(func() {
require.NoError(c())
})
d, err := sql.Open(dialect, u)
require.NoError(err)
oState := schema.TestCloneMigrationStates(t)
nState := schema.TestCreatePartialMigrationState(oState["postgres"], targetMigration)
oState["postgres"] = nState
m, err := schema.NewManager(ctx, dialect, d, schema.WithMigrationStates(oState))
require.NoError(err)
require.NoError(m.RollForward(ctx))
state, err := m.CurrentState(ctx)
require.NoError(err)
require.Equal(targetMigration, state.DatabaseSchemaVersion)
require.False(state.Dirty)
return d
}

@ -4,7 +4,7 @@ package schema
func init() {
migrationStates["postgres"] = migrationState{
binarySchemaVersion: 11001,
binarySchemaVersion: 12001,
upMigrations: map[int][]byte{
1: []byte(`
create domain wt_public_id as text
@ -6100,6 +6100,27 @@ alter table server
foreign key (type) references server_type_enm(name)
on update cascade
on delete restrict;
`),
12001: []byte(`
create function wt_sub_seconds(sec integer, ts timestamp with time zone)
returns timestamp with time zone
as $$
select ts - sec * '1 second'::interval;
$$ language sql
stable
returns null on null input;
comment on function wt_sub_seconds is
'wt_sub_seconds returns ts - sec.';
create function wt_sub_seconds_from_now(sec integer)
returns timestamp with time zone
as $$
select wt_sub_seconds(sec, current_timestamp);
$$ language sql
stable
returns null on null input;
comment on function wt_sub_seconds_from_now is
'wt_sub_seconds_from_now returns current_timestamp - sec.';
`),
2001: []byte(`
-- log_migration entries represent logs generated during migrations

@ -38,7 +38,7 @@ type Controller struct {
workerAuthCache *cache.Cache
// Used for testing
// Used for testing and tracking worker health
workerStatusUpdateTimes *sync.Map
// Repo factory methods
@ -117,7 +117,9 @@ func New(conf *Config) (*Controller, error) {
jobRepoFn := func() (*job.Repository, error) {
return job.NewRepository(dbase, dbase, c.kms)
}
c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger)
// TODO: the RunJobsLimit is temporary until a better fix gets in. This
// currently caps the scheduler at running 10 jobs per interval.
c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger, scheduler.WithRunJobsLimit(10))
if err != nil {
return nil, fmt.Errorf("error creating new scheduler: %w", err)
}
@ -182,7 +184,29 @@ func (c *Controller) Start() error {
func (c *Controller) registerJobs() error {
rw := db.New(c.conf.Database)
return vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger)
if err := vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger); err != nil {
return err
}
if err := c.registerSessionCleanupJob(); err != nil {
return err
}
return nil
}
// registerSessionCleanupJob is a helper method to abstract
// registering the session cleanup job specifically.
func (c *Controller) registerSessionCleanupJob() error {
sessionCleanupJob, err := newSessionCleanupJob(c.logger, c.SessionRepoFn, int(c.conf.StatusGracePeriodDuration.Seconds()))
if err != nil {
return fmt.Errorf("error creating session cleanup job: %w", err)
}
if err = c.scheduler.RegisterJob(c.baseContext, sessionCleanupJob); err != nil {
return fmt.Errorf("error registering session cleanup job: %w", err)
}
return nil
}
func (c *Controller) Shutdown(serversOnly bool) error {

@ -76,12 +76,19 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
Controllers: controllers,
}
// TODO (jeff, 05-2021): We should possibly list all connections here with
// their statuses, and check those against the found connections below. If
// any we think should be closed are marked open by the worker, the worker
// should be told to close them.
var foundConns []string
var (
// For tracking the reported open connections.
reportedOpenConns []string
// For tracking the session IDs we've already requested
// cancellation for. We won't need to add connection cancel
// requests for these because canceling the session terminates the
// connections.
requestedSessionCancelIds []string
)
// This is a map of all sessions and their statuses. We keep track of
// this for easy lookup if we need to make change requests.
sessionStatuses := make(map[string]pbs.SESSIONSTATUS)
for _, jobStatus := range req.GetJobs() {
switch jobStatus.Job.GetType() {
@ -92,6 +99,9 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
return nil, status.Error(codes.Internal, "Error getting session info at status time")
}
// Record status.
sessionStatuses[si.GetSessionId()] = si.Status
// Check connections before potentially bypassing the rest of the
// logic in the switch on si.Status.
sessConns := si.GetConnections()
@ -103,7 +113,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
// report as found, so that we should attempt to close it.
// Note that unspecified is the default state for the enum
// but it's not ever explicitly set by us.
foundConns = append(foundConns, conn.GetConnectionId())
reportedOpenConns = append(reportedOpenConns, conn.GetConnectionId())
}
}
@ -127,7 +137,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
}
// If the session from the DB is in canceling status, and we're
// here, it means the job is in pending or active; cancel it. If
// it's in termianted status something went wrong and we're
// it's in terminated status something went wrong and we're
// mismatched, so ensure we cancel it also.
currState := sessionInfo.States[0].Status
if currState.ProtoVal() != si.Status {
@ -148,13 +158,53 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
},
RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
})
// Log the session ID so we don't add a duplicate change
// request on connection normalization.
requestedSessionCancelIds = append(requestedSessionCancelIds, sessionId)
}
}
}
}
// Run our connection cleanup function
closedConns, err := sessRepo.CloseDeadConnectionsOnWorkerReport(ctx, req.Worker.PrivateId, foundConns)
// Normalize the current state of connections on the worker side
// with the data from the controller. In other words, if one of our
// found connections isn't supposed to be alive still, kill it.
//
// This is separate from the above session normalization and is
// additive to it; we don't add sessions that have already been
// handled there, as canceling a session already closes its
// connections.
shouldCloseConnections, err := sessRepo.ShouldCloseConnectionsOnWorker(ctx, reportedOpenConns, requestedSessionCancelIds)
if err != nil {
return nil, status.Errorf(codes.Internal, "Error fetching connections that should be closed: %v", err)
}
for sessionId, connIds := range shouldCloseConnections {
var connChanges []*pbs.Connection
for _, connId := range connIds {
connChanges = append(connChanges, &pbs.Connection{
ConnectionId: connId,
Status: session.StatusClosed.ProtoVal(),
})
}
ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{
Job: &pbs.Job{
Type: pbs.JOBTYPE_JOBTYPE_SESSION,
JobInfo: &pbs.Job_SessionInfo{
SessionInfo: &pbs.SessionJobInfo{
SessionId: sessionId,
Status: sessionStatuses[sessionId],
Connections: connChanges,
},
},
},
RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
})
}
// Run our controller-side cleanup function.
closedConns, err := sessRepo.CloseDeadConnectionsForWorker(ctx, req.Worker.PrivateId, reportedOpenConns)
if err != nil {
return nil, status.Errorf(codes.Internal, "Error closing dead conns for worker %s: %v", req.Worker.PrivateId, err)
}

@ -0,0 +1,133 @@
package controller
import (
"context"
"fmt"
"time"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/scheduler"
"github.com/hashicorp/boundary/internal/servers/controller/common"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/go-hclog"
)
// sessionCleanupJob defines a periodic job that monitors workers for
// loss of connection and terminates connections on workers that have
// not sent a heartbeat in a significant period of time.
//
// Relevant connections are simply marked as disconnected in the
// database. Connections will be terminated independently by the
// worker; in the event of a synchronization issue between the two,
// the controller will win out and order that the connections be
// closed on the worker.
type sessionCleanupJob struct {
logger hclog.Logger
sessionRepoFn common.SessionRepoFactory
// The amount of time to give disconnected workers before marking
// their connections as closed.
gracePeriod int
// The total number of connections closed in the last run.
totalClosed int
}
// newSessionCleanupJob instantiates the session cleanup job.
func newSessionCleanupJob(
logger hclog.Logger,
sessionRepoFn common.SessionRepoFactory,
gracePeriod int,
) (*sessionCleanupJob, error) {
const op = "controller.newNewSessionCleanupJob"
switch {
case logger == nil:
return nil, errors.New(errors.InvalidParameter, op, "missing logger")
case sessionRepoFn == nil:
return nil, errors.New(errors.InvalidParameter, op, "missing sessionRepoFn")
case gracePeriod < session.DeadWorkerConnCloseMinGrace:
return nil, errors.New(
errors.InvalidParameter, op, fmt.Sprintf("invalid gracePeriod, must be at least %d", session.DeadWorkerConnCloseMinGrace))
}
return &sessionCleanupJob{
logger: logger,
sessionRepoFn: sessionRepoFn,
gracePeriod: gracePeriod,
}, nil
}
// Name returns a short, unique name for the job.
func (j *sessionCleanupJob) Name() string { return "session_cleanup" }
// Description returns the description for the job.
func (j *sessionCleanupJob) Description() string {
return "Clean up session connections from disconnected workers"
}
// NextRunIn returns the next run time after a job is completed.
//
// The next run time is defined for sessionCleanupJob as one second.
// This is because the job should run continuously to terminate
// connections as soon as a worker has not reported in for a long
// enough time. Only one job will ever run at once, so there is no
// reason why it cannot run again immediately.
func (j *sessionCleanupJob) NextRunIn() (time.Duration, error) { return time.Second, nil }
// Status returns the status of the running job.
func (j *sessionCleanupJob) Status() scheduler.JobStatus {
return scheduler.JobStatus{
Completed: j.totalClosed,
Total: j.totalClosed,
}
}
// Run executes the job.
func (j *sessionCleanupJob) Run(ctx context.Context) error {
const op = "controller.(sessionCleanupJob).Run"
j.logger.Debug(
"starting job",
"op", op,
)
j.totalClosed = 0
// Load repos.
sessionRepo, err := j.sessionRepoFn()
if err != nil {
return errors.Wrap(err, op, errors.WithMsg("error getting session repo"))
}
// Run the atomic dead worker cleanup job.
results, err := sessionRepo.CloseConnectionsForDeadWorkers(ctx, j.gracePeriod)
if err != nil {
return errors.Wrap(err, op)
}
if len(results) < 1 {
j.logger.Debug(
"all workers OK, no connections to close",
"op", op,
)
} else {
for _, result := range results {
// Log a closed connection message for each worker as a warning
j.logger.Warn(
"worker has not reported status within acceptable grace period, all connections closed",
"op", op,
"private_id", result.ServerId,
"update_time", result.LastUpdateTime,
"grace_period_seconds", j.gracePeriod,
"number_connections_closed", result.NumberConnectionsClosed,
)
j.totalClosed += result.NumberConnectionsClosed
}
}
j.logger.Debug(
"job finished",
"op", op,
"total_connections_closed", j.totalClosed,
)
return nil
}

@ -0,0 +1,170 @@
package controller
import (
"context"
"fmt"
"testing"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/scheduler"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// assert the interface
var _ = scheduler.Job(new(sessionCleanupJob))
// This test has been largely adapted from
// TestRepository_CloseDeadConnectionsOnWorker in
// internal/session/repository_connection_test.go.
func TestSessionCleanupJob(t *testing.T) {
t.Parallel()
require, assert := require.New(t), assert.New(t)
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
serversRepo, err := servers.NewRepository(rw, rw, kms)
require.NoError(err)
sessionRepo, err := session.NewRepository(rw, rw, kms)
require.NoError(err)
ctx := context.Background()
numConns := 12
// Create two "workers". One will remain untouched while the other "goes
// away and comes back" (worker 2).
worker1 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker1"))
worker2 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker2"))
// Create a few sessions on each, activate, and authorize a connection
var connIds []string
connIdsByWorker := make(map[string][]string)
for i := 0; i < numConns; i++ {
serverId := worker1.PrivateId
if i%2 == 0 {
serverId = worker2.PrivateId
}
sess := session.TestDefaultSession(t, conn, wrapper, iamRepo, session.WithServerId(serverId), session.WithDbOpts(db.WithSkipVetForWrite(true)))
sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo"))
require.NoError(err)
c, cs, _, err := sessionRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(session.StatusAuthorized, cs[0].Status)
connIds = append(connIds, c.GetPublicId())
if i%2 == 0 {
connIdsByWorker[worker2.PrivateId] = append(connIdsByWorker[worker2.PrivateId], c.GetPublicId())
} else {
connIdsByWorker[worker1.PrivateId] = append(connIdsByWorker[worker1.PrivateId], c.GetPublicId())
}
}
// Mark half of the connections connected and leave the others authorized.
// This is just to ensure we have a spread when we test it out.
for i, connId := range connIds {
if i%2 == 0 {
_, cs, err := sessionRepo.ConnectConnection(ctx, session.ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
EndpointTcpPort: 22,
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == session.StatusAuthorized {
foundAuthorized = true
}
if status.Status == session.StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
}
}
// Create the job.
job, err := newSessionCleanupJob(
hclog.New(&hclog.LoggerOptions{Level: hclog.Trace}),
func() (*session.Repository, error) { return sessionRepo, nil },
session.DeadWorkerConnCloseMinGrace,
)
require.NoError(err)
// sleep the status grace period.
time.Sleep(time.Second * time.Duration(session.DeadWorkerConnCloseMinGrace))
// Push an upsert to the first worker so that its status has been
// updated.
_, rowsUpdated, err := serversRepo.UpsertServer(ctx, worker1, []servers.Option{}...)
require.NoError(err)
require.Equal(1, rowsUpdated)
// Run the job.
require.NoError(job.Run(ctx))
// Assert connection state on both workers.
assertConnections := func(workerId string, closed bool) {
connIds, ok := connIdsByWorker[workerId]
require.True(ok)
require.Len(connIds, 6)
for _, connId := range connIds {
_, states, err := sessionRepo.LookupConnection(ctx, connId, nil)
require.NoError(err)
var foundClosed bool
for _, state := range states {
if state.Status == session.StatusClosed {
foundClosed = true
break
}
}
assert.Equal(closed, foundClosed)
}
}
// Assert that all connections on the second worker are closed
assertConnections(worker2.PrivateId, true)
// Assert that all connections on the first worker are still open
assertConnections(worker1.PrivateId, false)
}
func TestSessionCleanupJobNewJobErr(t *testing.T) {
t.Parallel()
const op = "controller.newNewSessionCleanupJob"
require := require.New(t)
job, err := newSessionCleanupJob(nil, nil, 0)
require.Equal(err, errors.E(
errors.WithCode(errors.InvalidParameter),
errors.WithOp(op),
errors.WithMsg("missing logger"),
))
require.Nil(job)
job, err = newSessionCleanupJob(hclog.New(nil), nil, 0)
require.Equal(err, errors.E(
errors.WithCode(errors.InvalidParameter),
errors.WithOp(op),
errors.WithMsg("missing sessionRepoFn"),
))
require.Nil(job)
job, err = newSessionCleanupJob(hclog.New(nil), func() (*session.Repository, error) { return nil, nil }, 0)
require.Equal(err, errors.E(
errors.WithCode(errors.InvalidParameter),
errors.WithOp(op),
errors.WithMsg(fmt.Sprintf("invalid gracePeriod, must be at least %d", session.DeadWorkerConnCloseMinGrace)),
))
require.Nil(job)
}

@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/hashicorp/boundary/api"
"github.com/hashicorp/boundary/api/authmethods"
@ -371,6 +372,10 @@ type TestControllerOpts struct {
// A cluster address for overriding the advertised controller listener
// (overrides address provided in config, if any)
PublicClusterAddr string
// The amount of time to wait before marking connections as closed when a
// worker has not reported in
StatusGracePeriodDuration time.Duration
}
func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
@ -456,6 +461,9 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
t.Fatal(err)
}
// Initialize status grace period
tc.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration)
if opts.Config.Controller == nil {
opts.Config.Controller = new(config.Controller)
}
@ -615,6 +623,7 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
DisableKmsKeyCreation: true,
DisableAuthMethodCreation: true,
PublicClusterAddr: opts.PublicClusterAddr,
StatusGracePeriodDuration: opts.StatusGracePeriodDuration,
}
if opts.Logger != nil {
nextOpts.Logger = opts.Logger
@ -629,3 +638,71 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
}
return NewTestController(t, nextOpts)
}
// WaitForNextWorkerStatusUpdate waits for the next status check from a worker to
// come in. If it does not come in within the controller's status grace
// period, this function returns an error.
func (tc *TestController) WaitForNextWorkerStatusUpdate(workerId string) error {
tc.Logger().Debug("waiting for next status report from worker", "worker", workerId)
if err := tc.waitForNextWorkerStatusUpdate(workerId); err != nil {
tc.Logger().Error("error waiting for next status report from worker", "worker", workerId, "err", err)
return err
}
tc.Logger().Debug("waiting for next status report from worker received successfully", "worker", workerId)
return nil
}
func (tc *TestController) waitForNextWorkerStatusUpdate(workerId string) error {
waitStatusStart := time.Now()
ctx, cancel := context.WithTimeout(tc.ctx, tc.b.StatusGracePeriodDuration)
defer cancel()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
// pass
}
var waitStatusCurrent time.Time
var err error
tc.Controller().WorkerStatusUpdateTimes().Range(func(k, v interface{}) bool {
if k == nil || v == nil {
err = fmt.Errorf("nil key or value on entry: key=%#v value=%#v", k, v)
return false
}
workerStatusUpdateId, ok := k.(string)
if !ok {
err = fmt.Errorf("unexpected type %T for key: key=%#v value=%#v", k, k, v)
return false
}
workerStatusUpdateTime, ok := v.(time.Time)
if !ok {
err = fmt.Errorf("unexpected type %T for value: key=%#v value=%#v", k, k, v)
return false
}
if workerStatusUpdateId == workerId {
waitStatusCurrent = workerStatusUpdateTime
return false
}
return true
})
if err != nil {
return err
}
if waitStatusCurrent.Sub(waitStatusStart) > 0 {
break
}
}
return nil
}

@ -37,8 +37,9 @@ func WithLimit(limit int) Option {
}
}
// WithSkipVetForWrite provides an option to allow skipping vet checks to allow
// testing lower-level SQL triggers and constraints
// WithLiveness indicates how far back we want to search for server entries.
// Use 0 for the default liveness (15 seconds). A liveness value of -1 removes
// the liveness condition.
func WithLiveness(liveness time.Duration) Option {
return func(o *options) {
o.withLiveness = liveness

@ -13,7 +13,15 @@ import (
)
const (
defaultLiveness = 15 * time.Second
// DefaultLiveness is the setting that controls the server "liveness" time,
// or the maximum amount of time a worker can go without sending a status
// update to the controller. After this, the server is considered dead; it is
// taken out of the rotation of workers eligible for connections, and its
// connections may start to be terminated and marked as closed,
// depending on the grace period setting (see
// base.Server.StatusGracePeriodDuration). This value serves as the default
// and minimum allowable setting for the grace period.
DefaultLiveness = 15 * time.Second
)
type ServerType string
@ -67,18 +75,27 @@ func (r *Repository) listServersWithReader(ctx context.Context, reader db.Reader
opts := getOpts(opt...)
liveness := opts.withLiveness
if liveness == 0 {
liveness = defaultLiveness
liveness = DefaultLiveness
}
var where string
if liveness > 0 {
where = fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds()))
} else {
where = "type = $1"
}
var servers []*Server
if err := reader.SearchWhere(
ctx,
&servers,
fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())),
where,
[]interface{}{serverType},
db.WithLimit(-1),
); err != nil {
return nil, errors.Wrap(err, "servers.listServersWithReader")
}
return servers, nil
}

@ -1,12 +1,14 @@
package servers_test
import (
"context"
"testing"
"time"
"github.com/hashicorp/boundary/api/roles"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/types/scope"
@ -180,3 +182,80 @@ func TestTagUpdatingListing(t *testing.T) {
}
require.Equal(exp, tags)
}
func TestListServersWithLiveness(t *testing.T) {
t.Parallel()
require := require.New(t)
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
serversRepo, err := servers.NewRepository(rw, rw, kms)
require.NoError(err)
ctx := context.Background()
newServer := func(privateId string) *servers.Server {
result := &servers.Server{
PrivateId: privateId,
Type: "worker",
Address: "127.0.0.1",
}
_, rowsUpdated, err := serversRepo.UpsertServer(ctx, result)
require.NoError(err)
require.Equal(1, rowsUpdated)
return result
}
server1 := newServer("test1")
server2 := newServer("test2")
server3 := newServer("test3")
// Sleep the default liveness time (15sec currently) +1s
time.Sleep(time.Second * 16)
// Push an upsert to the first worker so that its status has been
// updated.
_, rowsUpdated, err := serversRepo.UpsertServer(ctx, server1)
require.NoError(err)
require.Equal(1, rowsUpdated)
requireIds := func(expected []string, actual []*servers.Server) {
require.Len(expected, len(actual))
want := make(map[string]struct{})
for _, v := range expected {
want[v] = struct{}{}
}
got := make(map[string]struct{})
for _, v := range actual {
got[v.PrivateId] = struct{}{}
}
require.Equal(want, got)
}
// Default liveness, should only list 1
result, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker)
require.NoError(err)
require.Len(result, 1)
requireIds([]string{server1.PrivateId}, result)
// Upsert second server.
_, rowsUpdated, err = serversRepo.UpsertServer(ctx, server2)
require.NoError(err)
require.Equal(1, rowsUpdated)
// Static liveness. Should get two, so long as this did not take
// more than 5s to execute.
result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(time.Second*5))
require.NoError(err)
require.Len(result, 2)
requireIds([]string{server1.PrivateId, server2.PrivateId}, result)
// Liveness disabled, should get all three servers.
result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(-1))
require.NoError(err)
require.Len(result, 3)
requireIds([]string{server1.PrivateId, server2.PrivateId, server3.PrivateId}, result)
}

@ -3,8 +3,6 @@ package worker
import (
"context"
"math/rand"
"os"
"strconv"
"time"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
@ -15,48 +13,10 @@ import (
// In the future we could make this configurable
const (
statusInterval = 2 * time.Second
statusTimeout = 5 * time.Second
defaultStatusGracePeriod = 30 * time.Second
statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD"
statusInterval = 2 * time.Second
statusTimeout = 5 * time.Second
)
// statusGracePeriod returns the status grace period setting for this
// worker, in seconds.
//
// The grace period is the length of time we allow connections to run
// on a worker in the event of an error sending status updates. The
// period is defined the length of time since the last successful
// update.
//
// The setting is derived from one of the following:
//
// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an
// integer value to define the setting.
// * If this is missing, the default (30 seconds) is used.
//
func (w *Worker) statusGracePeriod() time.Duration {
if v := os.Getenv(statusGracePeriodEnvVar); v != "" {
n, err := strconv.Atoi(v)
if err != nil {
w.logger.Error("could not read setting for BOUNDARY_STATUS_GRACE_PERIOD, using default",
"err", err,
"value", v,
)
return defaultStatusGracePeriod
}
if n < 1 {
w.logger.Error("invalid setting for BOUNDARY_STATUS_GRACE_PERIOD, using default", "value", v)
return defaultStatusGracePeriod
}
return time.Second * time.Duration(n)
}
return defaultStatusGracePeriod
}
type LastStatusInformation struct {
*pbs.StatusResponse
StatusTime time.Time
@ -99,6 +59,37 @@ func (w *Worker) LastStatusSuccess() *LastStatusInformation {
return w.lastStatusSuccess.Load().(*LastStatusInformation)
}
// WaitForNextSuccessfulStatusUpdate waits for the next successful status. It's
// used by testing (and in the future, shutdown) in place of a more opaque and
// possibly unnecessarily long sleep for things like initial controller
// check-in, etc.
//
// The timeout is aligned with the worker's status grace period. A nil error
// means the status was sent successfully.
func (w *Worker) WaitForNextSuccessfulStatusUpdate() error {
w.logger.Debug("waiting for next status report to controller")
waitStatusStart := time.Now()
ctx, cancel := context.WithTimeout(w.baseContext, w.conf.StatusGracePeriodDuration)
defer cancel()
for {
select {
case <-time.After(time.Second):
// pass
case <-ctx.Done():
w.logger.Error("error waiting for next status report to controller", "err", ctx.Err())
return ctx.Err()
}
if w.lastSuccessfulStatusTime().Sub(waitStatusStart) > 0 {
break
}
}
w.logger.Debug("next worker status update sent successfully")
return nil
}
func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
// First send info as-is. We'll perform cleanup duties after we
// get cancel/job change info back.
@ -196,6 +187,7 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
w.lastStatusSuccess.Store(&LastStatusInformation{StatusResponse: result, StatusTime: time.Now()})
for _, request := range result.GetJobsRequests() {
w.logger.Trace("got job request from controller", "request", request)
switch request.GetRequestType() {
case pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE:
switch request.GetJob().GetType() {
@ -204,12 +196,25 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
sessionId := sessInfo.GetSessionId()
siRaw, ok := w.sessionInfoMap.Load(sessionId)
if !ok {
w.logger.Warn("asked to cancel session but could not find a local information for it", "session_id", sessionId)
w.logger.Warn("session change requested but could not find local information for it", "session_id", sessionId)
continue
}
si := siRaw.(*sessionInfo)
si.Lock()
si.status = sessInfo.GetStatus()
// Update connection state if there are any connections in
// the request.
for _, conn := range sessInfo.GetConnections() {
connId := conn.GetConnectionId()
connInfo, ok := si.connInfoMap[conn.GetConnectionId()]
if !ok {
w.logger.Warn("connection change requested but could not find local information for it", "connection_id", connId)
continue
}
connInfo.status = conn.GetStatus()
}
si.Unlock()
}
}
@ -240,22 +245,31 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING,
si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_TERMINATED,
time.Until(si.lookupSessionResponse.Expiration.AsTime()) < 0:
var toClose int
for k, v := range si.connInfoMap {
if v.closeTime.IsZero() {
toClose++
v.connCancel()
w.logger.Info("terminated connection due to cancellation or expiration", "session_id", si.id, "connection_id", k)
closeInfo[k] = si.id
}
// Cancel connections without regard to individual connection
// state.
closedIds := w.cancelConnections(si.connInfoMap, true)
for _, connId := range closedIds {
closeInfo[connId] = si.id
w.logClose(si.id, connId)
}
// closeTime is marked by closeConnections iff the
// status is returned for that connection as closed. If
// the session is no longer valid and all connections
// are marked closed, clean up the session.
if toClose == 0 {
if len(closedIds) == 0 {
cleanSessionIds = append(cleanSessionIds, si.id)
}
default:
// Cancel connections *with* regard to individual connection
// state (ie: only ones that the controller has requested be
// terminated).
closedIds := w.cancelConnections(si.connInfoMap, false)
for _, connId := range closedIds {
closeInfo[connId] = si.id
w.logClose(si.id, connId)
}
}
return true
@ -277,6 +291,33 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
}
}
// cancelConnections is run by cleanupConnections to iterate over a
// session's connInfoMap and close connections based on the
// connection's state (or regardless if ignoreConnectionState is
// set).
//
// The returned slice contains the IDs of the connections that were
// closed.
func (w *Worker) cancelConnections(connInfoMap map[string]*connInfo, ignoreConnectionState bool) []string {
var closedIds []string
for k, v := range connInfoMap {
if v.closeTime.IsZero() {
if !ignoreConnectionState && v.status != pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED {
continue
}
v.connCancel()
closedIds = append(closedIds, k)
}
}
return closedIds
}
func (w *Worker) logClose(sessionId, connId string) {
w.logger.Info("terminated connection due to cancellation or expiration", "session_id", sessionId, "connection_id", connId)
}
func (w *Worker) lastSuccessfulStatusTime() time.Time {
lastStatus := w.LastStatusSuccess()
if lastStatus == nil {
@ -288,7 +329,7 @@ func (w *Worker) lastSuccessfulStatusTime() time.Time {
func (w *Worker) isPastGrace() (bool, time.Time, time.Duration) {
t := w.lastSuccessfulStatusTime()
u := w.statusGracePeriod()
u := w.conf.StatusGracePeriodDuration
v := time.Since(t)
return v > u, t, u
}

@ -0,0 +1,56 @@
package worker
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
func TestWorkerWaitForNextSuccessfulStatusUpdate(t *testing.T) {
t.Parallel()
require := require.New(t)
for _, name := range []string{"ok", "timeout"} {
t.Run(name, func(t *testing.T) {
// As-needed initialization of a mock worker
w := &Worker{
logger: hclog.New(nil),
lastStatusSuccess: new(atomic.Value),
baseContext: context.Background(),
conf: &Config{
Server: &base.Server{
StatusGracePeriodDuration: time.Second * 1,
},
},
}
// This is present in New()
w.lastStatusSuccess.Store((*LastStatusInformation)(nil))
var wg sync.WaitGroup
var err error
wg.Add(1)
go func() {
err = w.WaitForNextSuccessfulStatusUpdate()
wg.Done()
}()
if name == "ok" {
time.Sleep(time.Millisecond * 100)
w.lastStatusSuccess.Store(&LastStatusInformation{StatusTime: time.Now()})
}
wg.Wait()
if name == "timeout" {
require.ErrorIs(err, context.DeadlineExceeded)
} else {
require.NoError(err)
}
})
}
}

@ -176,6 +176,10 @@ type TestWorkerOpts struct {
// The logger to use, or one will be created
Logger hclog.Logger
// The amount of time to wait before marking connections as closed when a
// connection cannot be made back to the controller
StatusGracePeriodDuration time.Duration
}
func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker {
@ -219,6 +223,9 @@ func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker {
})
}
// Initialize status grace period
tw.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration)
if opts.Config.Worker == nil {
opts.Config.Worker = new(config.Worker)
}
@ -278,10 +285,11 @@ func (tw *TestWorker) AddClusterWorkerMember(t *testing.T, opts *TestWorkerOpts)
opts = new(TestWorkerOpts)
}
nextOpts := &TestWorkerOpts{
WorkerAuthKms: tw.w.conf.WorkerAuthKms,
Name: opts.Name,
InitialControllers: tw.ControllerAddrs(),
Logger: tw.w.conf.Logger,
WorkerAuthKms: tw.w.conf.WorkerAuthKms,
Name: opts.Name,
InitialControllers: tw.ControllerAddrs(),
Logger: tw.w.conf.Logger,
StatusGracePeriodDuration: opts.StatusGracePeriodDuration,
}
if opts.Logger != nil {
nextOpts.Logger = opts.Logger

@ -28,8 +28,8 @@ type Worker struct {
started *ua.Bool
controllerStatusConn *atomic.Value
workerStartTime time.Time
lastStatusSuccess *atomic.Value
workerStartTime time.Time
controllerResolver *atomic.Value

@ -235,13 +235,13 @@ where
)
`
// connectionsToCloseCte finds connections that are:
// closeDeadConnectionsCte finds connections that are:
//
// * not closed
// * not announced by a given server in its latest update
//
// and marks them as closed.
connectionsToCloseCte = `
closeDeadConnectionsCte = `
with
-- Find connections that are not closed so we can reference those IDs
unclosed_connections as (
@ -257,7 +257,7 @@ with
-- It's not in limbo between when it moved into this state and when
-- it started being reported by the worker, which is roughly every
-- 2-3 seconds
start_time < now() - interval '10 seconds'
start_time < wt_sub_seconds_from_now(10)
),
connections_to_close as (
select public_id
@ -278,5 +278,100 @@ with
where
public_id in (select public_id from connections_to_close)
returning public_id
`
// closeConnectionsForDeadServersCte finds connections that are:
//
// * not closed
// * belong to servers that have not reported in within an acceptable
// threshold of time
//
// and marks them as closed.
//
// The query returns the set of servers that have had connections closed
// along with their last update time and the number of connections closed on
// each.
closeConnectionsForDeadServersCte = `
with
-- Get dead servers, parameterized off of grace period in seconds
dead_servers as (
select private_id, update_time
from server
where update_time < wt_sub_seconds_from_now($1)
),
-- Find connections that are not closed so we can reference those IDs
unclosed_connections as (
select connection_id
from session_connection_state
where
-- It's the current state
end_time is null
and
-- Current state isn't closed state
state in ('authorized', 'connected')
and
-- It's not in limbo between when it moved into this state and when
-- it started being reported by the worker, which is roughly every
-- 2-3 seconds
start_time < wt_sub_seconds_from_now(10)
),
connections_to_close as (
select public_id
from session_connection
where
-- Related to the worker that just reported to us
server_id in (select private_id from dead_servers)
and
-- Only unclosed ones
public_id in (select connection_id from unclosed_connections)
),
closed_connections as (
update session_connection
set
closed_reason = 'system error'
where
public_id in (select public_id from connections_to_close)
returning public_id, server_id
)
select
dead_servers.private_id,
dead_servers.update_time,
count(closed_connections.public_id)
from dead_servers
left join closed_connections
on dead_servers.private_id = closed_connections.server_id
group by dead_servers.private_id, dead_servers.update_time
having count(closed_connections.public_id) > 0
order by dead_servers.private_id
`
// shouldCloseConnectionsCte finds connections that are marked as closed in
// the database given a set of connection IDs. They are returned along with
// their associated session ID.
//
// The second parameter is a set of session IDs that we have already
// submitted a session-wide close request for, so sending another change
// request for them would be redundant.
shouldCloseConnectionsCte = `
with
-- Find connections that are closed so we can reference those IDs
closed_connections as (
select connection_id
from session_connection_state
where
-- It's the current state
end_time is null
and
-- Current state is closed
state = 'closed'
and
connection_id in (%s)
)
select public_id, session_id
from session_connection
where
public_id in (select connection_id from closed_connections)
-- Below fmt arg is filled in if there are session IDs to filter against
%s
`
)

@ -4,11 +4,18 @@ import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/servers"
)
// DeadWorkerConnCloseMinGrace is the minimum allowable setting for
// the CloseConnectionsForDeadWorkers method. This is synced with
// the default server liveness setting.
var DeadWorkerConnCloseMinGrace = int(servers.DefaultLiveness.Seconds())
// LookupConnection will look up a connection in the repository and return the connection
// with its states. If the connection is not found, it will return nil, nil, nil.
// No options are currently supported.
@ -99,18 +106,16 @@ func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ ..
return rowsDeleted, nil
}
// CloseDeadConnectionsOnWorkerReport will run the connectionsToClose CTE to
// look for connections that should be marked closed because they are no longer
// claimed by a server. This does not detect connections where the server is no
// longer reporting status; that's for a different CTE to be called by a
// heartbeat detector.
// CloseDeadConnectionsForWorker will run closeDeadConnectionsCte to look for
// connections that should be marked closed because they are no longer claimed
// by a server.
//
// The foundConns input should be the currently-claimed connections; the CTE
// uses a NOT IN clause to ensure these are excluded. It is not an error for
// this to be empty as the worker could claim no connections; in that case all
// connections will immediately transition to closed.
func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, serverId string, foundConns []string) (int, error) {
const op = "session.(Repository).CloseDeadConnectionsOnWorkerReport"
func (r *Repository) CloseDeadConnectionsForWorker(ctx context.Context, serverId string, foundConns []string) (int, error) {
const op = "session.(Repository).CloseDeadConnectionsForWorker"
if serverId == "" {
return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing server id")
}
@ -135,7 +140,7 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
var err error
rowsAffected, err = w.Exec(ctx, fmt.Sprintf(connectionsToCloseCte, publicIdStr), args)
rowsAffected, err = w.Exec(ctx, fmt.Sprintf(closeDeadConnectionsCte, publicIdStr), args)
if err != nil {
return errors.Wrap(err, op)
}
@ -148,6 +153,129 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser
return rowsAffected, nil
}
type CloseConnectionsForDeadWorkersResult struct {
ServerId string
LastUpdateTime time.Time
NumberConnectionsClosed int
}
// CloseConnectionsForDeadWorkers will run
// closeConnectionsForDeadServersCte to look for connections that
// should be marked closed because they are on a server that is no longer
// sending status updates to the controller(s).
//
// The only input to the method is the grace period, in seconds.
func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod int) ([]CloseConnectionsForDeadWorkersResult, error) {
const op = "session.(Repository).CloseConnectionsForDeadWorkers"
if gracePeriod < DeadWorkerConnCloseMinGrace {
return nil, errors.New(
errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace))
}
results := make([]CloseConnectionsForDeadWorkersResult, 0)
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, []interface{}{gracePeriod})
if err != nil {
return errors.Wrap(err, op)
}
defer rows.Close()
for rows.Next() {
var (
serverId string
lastUpdateTime time.Time
numberConnectionsClosed int
)
if err := rows.Scan(&serverId, &lastUpdateTime, &numberConnectionsClosed); err != nil {
return errors.Wrap(err, op)
}
results = append(results, CloseConnectionsForDeadWorkersResult{
ServerId: serverId,
LastUpdateTime: lastUpdateTime,
NumberConnectionsClosed: numberConnectionsClosed,
})
}
return nil
},
)
if err != nil {
return nil, errors.Wrap(err, op)
}
return results, nil
}
// ShouldCloseConnectionsOnWorker will run shouldCloseConnectionsCte to look
// for connections that the worker should close because they are currently
// reporting them as open incorrectly.
//
// The foundConns input here is used to filter closed connection states. This
// is further filtered against the filterSessions input, which is expected to
// be a set of sessions we've already submitted close requests for, so adding
// them again would be redundant.
//
// The returned map[string][]string is indexed by session ID.
func (r *Repository) ShouldCloseConnectionsOnWorker(ctx context.Context, foundConns, filterSessions []string) (map[string][]string, error) {
const op = "session.(Repository).ShouldCloseConnectionsOnWorker"
if len(foundConns) < 1 {
return nil, nil // nothing to do
}
args := make([]interface{}, 0, len(foundConns)+len(filterSessions))
// foundConns first
connsParams := make([]string, len(foundConns))
for i, connId := range foundConns {
connsParams[i] = fmt.Sprintf("$%d", i+1)
args = append(args, connId)
}
connsStr := strings.Join(connsParams, ",")
// then filterSessions
var sessionsStr string
if len(filterSessions) > 0 {
offset := len(foundConns) + 1
sessionsParams := make([]string, len(filterSessions))
for i, sessionId := range filterSessions {
sessionsParams[i] = fmt.Sprintf("$%d", i+offset)
args = append(args, sessionId)
}
const sessionIdFmtStr = `and session_id not in (%s)`
sessionsStr = fmt.Sprintf(sessionIdFmtStr, strings.Join(sessionsParams, ","))
}
rows, err := r.reader.Query(
ctx,
fmt.Sprintf(shouldCloseConnectionsCte, connsStr, sessionsStr),
args,
)
if err != nil {
return nil, errors.Wrap(err, op)
}
defer rows.Close()
result := make(map[string][]string)
for rows.Next() {
var connectionId, sessionId string
if err := rows.Scan(&connectionId, &sessionId); err != nil {
return nil, errors.Wrap(err, op)
}
result[sessionId] = append(result[sessionId], connectionId)
}
return result, nil
}
func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) {
const op = "session.fetchConnectionStates"
var states []*ConnectionState

@ -2,6 +2,7 @@ package session
import (
"context"
"fmt"
"testing"
"time"
@ -10,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/shared-secure-libs/strutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -210,7 +212,7 @@ func TestRepository_DeleteConnection(t *testing.T) {
}
}
func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) {
func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) {
t.Parallel()
require, assert := require.New(t), assert.New(t)
conn, _ := db.TestSetup(t, "postgres")
@ -284,7 +286,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) {
// all connection IDs for worker 1 should be showing as non-closed, and
// the ones for worker 2 not advertised should be closed.
shouldStayOpen := worker2ConnIds[0:2]
count, err := repo.CloseDeadConnectionsOnWorkerReport(ctx, worker2.GetPrivateId(), shouldStayOpen)
count, err := repo.CloseDeadConnectionsForWorker(ctx, worker2.GetPrivateId(), shouldStayOpen)
require.NoError(err)
assert.Equal(4, count)
@ -310,7 +312,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) {
// Now, advertise none of the connection IDs for worker 2. This is mainly to
// test that handling the case where we do not include IDs works properly as
// it changes the where clause.
count, err = repo.CloseDeadConnectionsOnWorkerReport(ctx, worker1.GetPrivateId(), nil)
count, err = repo.CloseDeadConnectionsForWorker(ctx, worker1.GetPrivateId(), nil)
require.NoError(err)
assert.Equal(6, count)
@ -330,3 +332,440 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) {
assert.True(foundClosed != strutil.StrListContains(shouldStayOpen, conn.PublicId))
}
}
func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) {
t.Parallel()
require := require.New(t)
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(err)
serversRepo, err := servers.NewRepository(rw, rw, kms)
require.NoError(err)
ctx := context.Background()
// connection count = 6 * states(authorized, connected, closed = 3) * servers_with_open_connections(3)
numConns := 54
// Create four "workers". This is similar to the setup in
// TestRepository_CloseDeadConnectionsOnWorker, but a bit more complex;
// firstly, the last worker will have no connections at all, and we will be
// closing the others in stages to test multiple servers being closed at
// once.
worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1"))
worker2 := TestWorker(t, conn, wrapper, WithServerId("worker2"))
worker3 := TestWorker(t, conn, wrapper, WithServerId("worker3"))
worker4 := TestWorker(t, conn, wrapper, WithServerId("worker4"))
// Create sessions on the first three, activate, and authorize connections
var worker1ConnIds, worker2ConnIds, worker3ConnIds []string
for i := 0; i < numConns; i++ {
var serverId string
if i%3 == 0 {
serverId = worker1.PrivateId
} else if i%3 == 1 {
serverId = worker2.PrivateId
} else {
serverId = worker3.PrivateId
}
sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true)))
sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo"))
require.NoError(err)
c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
if i%3 == 0 {
worker1ConnIds = append(worker1ConnIds, c.GetPublicId())
} else if i%3 == 1 {
worker2ConnIds = append(worker2ConnIds, c.GetPublicId())
} else {
worker3ConnIds = append(worker3ConnIds, c.GetPublicId())
}
}
// Mark a third of the connections connected, a third closed, and leave the
// others authorized. This is just to ensure we have a spread when we test it
// out.
for i, connId := range func() []string {
var s []string
s = append(s, worker1ConnIds...)
s = append(s, worker2ConnIds...)
s = append(s, worker3ConnIds...)
return s
}() {
if i%3 == 0 {
_, cs, err := repo.ConnectConnection(ctx, ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
EndpointTcpPort: 22,
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
} else if i%3 == 1 {
resp, err := repo.CloseConnections(ctx, []CloseWith{
{
ConnectionId: connId,
ClosedReason: ConnectionCanceled,
},
})
require.NoError(err)
require.Len(resp, 1)
cs := resp[0].ConnectionStates
require.Len(cs, 2)
var foundAuthorized, foundClosed bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusClosed {
foundClosed = true
}
}
require.True(foundAuthorized)
require.True(foundClosed)
}
}
// There is a 15 second delay to account for time for the connections to
// transition
time.Sleep(15 * time.Second)
// updateServer is a helper for bumping the update time on our
// servers. The server is read back so that we can reference the
// most up-to-date fields.
updateServer := func(t *testing.T, w *servers.Server) *servers.Server {
t.Helper()
_, rowsUpdated, err := serversRepo.UpsertServer(ctx, w)
require.NoError(err)
require.Equal(1, rowsUpdated)
servers, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker)
require.NoError(err)
for _, server := range servers {
if server.PrivateId == w.PrivateId {
return server
}
}
require.FailNowf("server not found after updating", "server id: %q", w.PrivateId)
// Looks weird but needed to build, as we fail in testify instead
// of returning an error
return nil
}
// requireConnectionStatus is a helper that asserts the state of each
// connection for a worker: either all closed, or the expected
// connected/closed/authorized split set up earlier in the test.
requireConnectionStatus := func(t *testing.T, connIds []string, expectAllClosed bool) {
t.Helper()
var conns []*Connection
require.NoError(repo.list(ctx, &conns, "", nil))
for i, connId := range connIds {
var expected ConnectionStatus
switch {
case expectAllClosed:
expected = StatusClosed
case i%3 == 0:
expected = StatusConnected
case i%3 == 1:
expected = StatusClosed
case i%3 == 2:
expected = StatusAuthorized
}
_, states, err := repo.LookupConnection(ctx, connId)
require.NoError(err)
require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected)
}
}
// timestampPbAsUTC normalizes protobuf timestamps to UTC so that they
// compare equal to the times reported in
// CloseConnectionsForDeadWorkersResult.
timestampPbAsUTC := func(t *testing.T, tm time.Time) time.Time {
t.Helper()
utcLoc, err := time.LoadLocation("Etc/UTC")
require.NoError(err)
return tm.In(utcLoc)
}
// Now try some scenarios.
{
// First, test the error/validation case.
result, err := repo.CloseConnectionsForDeadWorkers(ctx, 0)
require.Equal(err, errors.E(
errors.WithCode(errors.InvalidParameter),
errors.WithOp("session.(Repository).CloseConnectionsForDeadWorkers"),
errors.WithMsg(fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace)),
))
require.Nil(result)
}
{
// Now try the base case, where all workers are reporting in.
worker1 = updateServer(t, worker1)
worker2 = updateServer(t, worker2)
worker3 = updateServer(t, worker3)
updateServer(t, worker4) // no re-assignment here because we never reference the server again
result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
require.NoError(err)
require.Empty(result)
// Expect appropriate split connection state on worker1
requireConnectionStatus(t, worker1ConnIds, false)
// Expect appropriate split connection state on worker2
requireConnectionStatus(t, worker2ConnIds, false)
// Expect appropriate split connection state on worker3
requireConnectionStatus(t, worker3ConnIds, false)
}
{
// Now try a zero case - similar to the base case, except that no results
// should be returned for workers with no connections, even if they are
// dead. Here, the worker with no connections is worker4.
time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
worker1 = updateServer(t, worker1)
worker2 = updateServer(t, worker2)
worker3 = updateServer(t, worker3)
result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
require.NoError(err)
require.Empty(result)
// Expect appropriate split connection state on worker1
requireConnectionStatus(t, worker1ConnIds, false)
// Expect appropriate split connection state on worker2
requireConnectionStatus(t, worker2ConnIds, false)
// Expect appropriate split connection state on worker3
requireConnectionStatus(t, worker3ConnIds, false)
}
{
// The first induction is letting the first worker "die" by not updating
// it along with the others. All of its authorized and connected
// connections should now be closed.
time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
worker2 = updateServer(t, worker2)
worker3 = updateServer(t, worker3)
result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
require.NoError(err)
// Assert that we have one result with the appropriate ID and
// number of connections closed.
require.Equal([]CloseConnectionsForDeadWorkersResult{
{
ServerId: worker1.PrivateId,
LastUpdateTime: timestampPbAsUTC(t, worker1.UpdateTime.AsTime()),
NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
},
}, result)
// Expect all connections closed on worker1
requireConnectionStatus(t, worker1ConnIds, true)
// Expect appropriate split connection state on worker2
requireConnectionStatus(t, worker2ConnIds, false)
// Expect appropriate split connection state on worker3
requireConnectionStatus(t, worker3ConnIds, false)
}
{
// The final case is having the other two workers die. After
// this, all connections should be closed, with results returned
// for the two remaining workers.
time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
require.NoError(err)
// Assert that we have one result with the appropriate ID and number of connections closed.
require.Equal([]CloseConnectionsForDeadWorkersResult{
{
ServerId: worker2.PrivateId,
LastUpdateTime: timestampPbAsUTC(t, worker2.UpdateTime.AsTime()),
NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
},
{
ServerId: worker3.PrivateId,
LastUpdateTime: timestampPbAsUTC(t, worker3.UpdateTime.AsTime()),
NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
},
}, result)
// Expect all connections closed on worker1
requireConnectionStatus(t, worker1ConnIds, true)
// Expect all connections closed on worker2
requireConnectionStatus(t, worker2ConnIds, true)
// Expect all connections closed on worker3
requireConnectionStatus(t, worker3ConnIds, true)
}
}
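// TestRepository_ShouldCloseConnectionsOnWorker exercises the repository
// function a controller uses to reconcile a worker's reported connections:
// given connection IDs (optionally filtered by session IDs), it returns a
// map of session ID to the connection IDs that should already be closed.
// A rough sketch of the call shape, with illustrative variable names:
//
// shouldClose, err := repo.ShouldCloseConnectionsOnWorker(ctx, reportedConnIds, sessionIdFilter)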
func TestRepository_ShouldCloseConnectionsOnWorker(t *testing.T) {
t.Parallel()
require := require.New(t)
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(err)
ctx := context.Background()
numConns := 12
// Create a worker; we only need one here as our query depends on
// connections, not workers.
worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1"))
// Create a few sessions, activate them, and authorize a connection on each.
var connIds []string
sessionConnIds := make(map[string][]string)
for i := 0; i < numConns; i++ {
serverId := worker1.PrivateId
sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true)))
sessionId := sess.GetPublicId()
sess, _, err = repo.ActivateSession(ctx, sessionId, sess.Version, serverId, "worker", []byte("foo"))
require.NoError(err)
c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
require.NoError(err)
require.Len(cs, 1)
require.Equal(StatusAuthorized, cs[0].Status)
connId := c.GetPublicId()
connIds = append(connIds, connId)
sessionConnIds[sessionId] = append(sessionConnIds[sessionId], connId)
}
// Mark half of the connections connected, close the other half.
for i, connId := range connIds {
if i%2 == 0 {
_, cs, err := repo.ConnectConnection(ctx, ConnectWith{
ConnectionId: connId,
ClientTcpAddress: "127.0.0.1",
ClientTcpPort: 22,
EndpointTcpAddress: "127.0.0.1",
EndpointTcpPort: 22,
})
require.NoError(err)
require.Len(cs, 2)
var foundAuthorized, foundConnected bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusConnected {
foundConnected = true
}
}
require.True(foundAuthorized)
require.True(foundConnected)
} else {
resp, err := repo.CloseConnections(ctx, []CloseWith{
{
ConnectionId: connId,
ClosedReason: ConnectionCanceled,
},
})
require.NoError(err)
require.Len(resp, 1)
cs := resp[0].ConnectionStates
require.Len(cs, 2)
var foundAuthorized, foundClosed bool
for _, status := range cs {
if status.Status == StatusAuthorized {
foundAuthorized = true
}
if status.Status == StatusClosed {
foundClosed = true
}
}
require.True(foundAuthorized)
require.True(foundClosed)
}
}
// There is a 15 second delay to account for time for the connections to
// transition
time.Sleep(15 * time.Second)
// Now we try some scenarios.
{
// First test an empty set.
result, err := repo.ShouldCloseConnectionsOnWorker(ctx, nil, nil)
require.NoError(err)
require.Zero(result, "should be empty when no connections are supplied")
}
{
// Here we pass in all of our connections without a filter on
// session. This should return half of the connections - the ones
// that we marked as closed.
//
// Build the expected map from our session map, keeping only the
// connections that have been closed.
expectedSessionConnIds := make(map[string][]string)
for sessionId, connIds := range sessionConnIds {
for _, connId := range connIds {
if testIsConnectionClosed(ctx, t, repo, connId) {
expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId)
}
}
}
// Send query, use all connections w/o a filter on sessions.
actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, nil)
require.NoError(err)
require.Equal(expectedSessionConnIds, actualSessionConnIds)
}
{
// Finally, add a session filter. We do this by just alternating
// the session IDs we want to filter on.
expectedSessionConnIds := make(map[string][]string)
var filterSessionIds []string
var filterSession bool
for sessionId, connIds := range sessionConnIds {
for _, connId := range connIds {
if testIsConnectionClosed(ctx, t, repo, connId) {
if !filterSession {
expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId)
} else {
filterSessionIds = append(filterSessionIds, sessionId)
}
// Toggle filterSession here (instead of in the outer session
// loop) so that we aren't just lining up with the
// connected/closed alternation.
filterSession = !filterSession
}
}
}
// Send query with the session filter.
actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, filterSessionIds)
require.NoError(err)
require.Equal(expectedSessionConnIds, actualSessionConnIds)
}
}
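// testIsConnectionClosed reports whether the most recent state for the
// given connection is closed.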
func testIsConnectionClosed(ctx context.Context, t *testing.T, repo *Repository, connId string) bool {
require := require.New(t)
_, states, err := repo.LookupConnection(ctx, connId)
require.NoError(err)
// Use the first state, as LookupConnection returns states ordered
// by start time, descending.
return states[0].Status == StatusClosed
}

@ -37,9 +37,9 @@ import (
)
const (
testSendRecvSendMax = 60
defaultGracePeriod = time.Second * 30
expectConnectionStateOnWorkerTimeout = defaultGracePeriod * 2
defaultGracePeriod = time.Second * 15
expectConnectionStateOnControllerTimeout = time.Minute * 2
expectConnectionStateOnWorkerTimeout = defaultGracePeriod * 2
// This is the interval that we check states on in the worker. It
// needs to be particularly granular to ensure that we allow for
@ -55,227 +55,414 @@ const (
expectConnectionStateOnWorkerInterval = time.Millisecond * 100
)
func TestWorkerSessionCleanup(t *testing.T) {
require := require.New(t)
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
pl, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
InitialResourcesSuffix: "1234567890",
Logger: logger.Named("c1"),
PublicClusterAddr: pl.Addr().String(),
})
defer c1.Shutdown()
expectWorkers(t, c1)
// Wire up the testing proxies
require.Len(c1.ClusterAddrs(), 1)
proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0],
dawdle.WithListener(pl),
dawdle.WithRbufSize(512),
dawdle.WithWbufSize(512),
)
require.NoError(err)
defer proxy.Close()
require.NotEmpty(t, proxy.ListenerAddr())
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: []string{proxy.ListenerAddr()},
Logger: logger.Named("w1"),
})
defer w1.Shutdown()
// timeoutBurdenType details our "burden cases" for the session
// cleanup tests.
//
// There are 3 burden cases:
//
// * default: This case simulates normal operation, where both worker
// and controller time out connections at roughly the same interval.
// In reality, this is not necessarily going to be the case, but it's
// hard to test the individual cases when both settings are the same.
//
// * worker: This case assumes the worker is the source of truth for
// connection state. Here, the controller's grace period is increased
// to a high factor over the default, so that the worker manages the
// lifecycle of a connection and properly reports it as closed once
// the link resumes, ensuring the connection is marked as closed on
// the controller as well.
//
// * controller: Here, the controller is the one doing the work. The
// connection will be open on the worker until status checks resume
// from the worker. At this point, the controller will request the
// status change on the worker, physically closing the connection
// there.
type timeoutBurdenType string
time.Sleep(10 * time.Second)
expectWorkers(t, c1, w1)
const (
timeoutBurdenTypeDefault timeoutBurdenType = "default"
timeoutBurdenTypeWorker timeoutBurdenType = "worker"
timeoutBurdenTypeController timeoutBurdenType = "controller"
)
// Use an independent context for test things that take a context so
// that we aren't tied to any timeouts in the controller, etc. This
// can interfere with some of the test operations.
ctx := context.Background()
var timeoutBurdenCases = []timeoutBurdenType{timeoutBurdenTypeDefault, timeoutBurdenTypeWorker, timeoutBurdenTypeController}
// Connect target
client := c1.Client()
client.SetToken(c1.Token().Token)
tcl := targets.NewClient(client)
tgt, err := tcl.Read(ctx, "ttcp_1234567890")
require.NoError(err)
require.NotNil(tgt)
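// controllerGracePeriod returns the status grace period to configure on
// the controller for the given burden case: skewed well above the default
// when the worker carries the burden, the default otherwise.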
func controllerGracePeriod(ty timeoutBurdenType) time.Duration {
if ty == timeoutBurdenTypeWorker {
return defaultGracePeriod * 10
}
// Create test server, update default port on target
ts := newTestTcpServer(t, logger)
require.NotNil(t, ts)
defer ts.Close()
tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()))
require.NoError(err)
require.NotNil(tgt)
return defaultGracePeriod
}
// Authorize and connect
sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890")
sConn := sess.Connect(ctx, t, logger)
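// workerGracePeriod returns the status grace period to configure on the
// worker for the given burden case: skewed well above the default when
// the controller carries the burden, the default otherwise.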
func workerGracePeriod(ty timeoutBurdenType) time.Duration {
if ty == timeoutBurdenTypeController {
return defaultGracePeriod * 10
}
// Run initial send/receive test, make sure things are working
sConn.TestSendRecvAll(t)
return defaultGracePeriod
}
// Kill the link
proxy.Pause()
// TestWorkerSessionCleanup is the main test for session cleanup, and
// dispatches to the individual subtests.
func TestWorkerSessionCleanup(t *testing.T) {
t.Parallel()
for _, burdenCase := range timeoutBurdenCases {
burdenCase := burdenCase
t.Run(string(burdenCase), func(t *testing.T) {
t.Parallel()
t.Run("single_controller", testWorkerSessionCleanupSingle(burdenCase))
t.Run("multi_controller", testWorkerSessionCleanupMulti(burdenCase))
})
}
}
// Run again, ensure connection is dead
sConn.TestSendRecvFail(t)
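// testWorkerSessionCleanupSingle returns the single-controller session
// cleanup subtest for the given burden case.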
func testWorkerSessionCleanupSingle(burdenCase timeoutBurdenType) func(t *testing.T) {
return func(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := hclog.New(&hclog.LoggerOptions{
Name: t.Name(),
Level: hclog.Trace,
})
conf, err := config.DevController()
require.NoError(err)
pl, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
InitialResourcesSuffix: "1234567890",
Logger: logger.Named("c1"),
PublicClusterAddr: pl.Addr().String(),
StatusGracePeriodDuration: controllerGracePeriod(burdenCase),
})
defer c1.Shutdown()
expectWorkers(t, c1)
// Wire up the testing proxies
require.Len(c1.ClusterAddrs(), 1)
proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0],
dawdle.WithListener(pl),
dawdle.WithRbufSize(256),
dawdle.WithWbufSize(256),
)
require.NoError(err)
defer proxy.Close()
require.NotEmpty(t, proxy.ListenerAddr())
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: []string{proxy.ListenerAddr()},
Logger: logger.Named("w1"),
StatusGracePeriodDuration: workerGracePeriod(burdenCase),
})
defer w1.Shutdown()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
err = c1.WaitForNextWorkerStatusUpdate(w1.Name())
require.NoError(err)
expectWorkers(t, c1, w1)
// Use an independent context for test things that take a context so
// that we aren't tied to any timeouts in the controller, etc. This
// can interfere with some of the test operations.
ctx := context.Background()
// Connect target
client := c1.Client()
client.SetToken(c1.Token().Token)
tcl := targets.NewClient(client)
tgt, err := tcl.Read(ctx, "ttcp_1234567890")
require.NoError(err)
require.NotNil(tgt)
// Create test server, update default port on target
ts := newTestTcpServer(t, logger)
require.NotNil(t, ts)
defer ts.Close()
tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1))
require.NoError(err)
require.NotNil(tgt)
// Authorize and connect
sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890")
sConn := sess.Connect(ctx, t)
// Run initial send/receive test, make sure things are working
logger.Debug("running initial send/recv test")
sConn.TestSendRecvAll(t)
// Kill the link
logger.Debug("pausing controller/worker link")
proxy.Pause()
// Wait for failure connection state (depends on burden case)
switch burdenCase {
case timeoutBurdenTypeWorker:
// Wait on worker, then check controller
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected)
case timeoutBurdenTypeController:
// Wait on controller, then check worker
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected)
default:
// Should be closed on both worker and controller. Wait on
// worker then check controller.
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
}
// Assert we have no connections left (should be default behavior)
sess.TestNoConnectionsLeft(t)
// Run send/receive test again to check expected connection-level
// behavior
if burdenCase == timeoutBurdenTypeController {
// Burden on controller, should be successful until connection
// resumes
sConn.TestSendRecvAll(t)
} else {
// Connection should die in other cases
sConn.TestSendRecvFail(t)
}
// Assert connection has been removed from the local worker state
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
// Resume the connection, and reconnect.
logger.Debug("resuming controller/worker link")
proxy.Resume()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
// Do something post-reconnect depending on burden case. Note in
// the default case, both worker and controller should be
// relatively in sync, so we don't worry about these
// post-reconnection assertions.
switch burdenCase {
case timeoutBurdenTypeWorker:
// If we are expecting the worker to be the source of truth of
// a connection status, ensure that our old session's
// connections are actually closed now that the worker is
// properly reporting in again.
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
case timeoutBurdenTypeController:
// If we are expecting the controller to be the source of
// truth, the connection should now be forcibly closed after
// the worker gets a status change request back.
sConn.TestSendRecvFail(t)
}
// Resume the connection, and reconnect.
proxy.Resume()
time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy
sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup()
sConn = sess.Connect(ctx, t, logger)
sConn.TestSendRecvAll(t)
// Proceed with new connection test
logger.Debug("connecting to new session after resuming controller/worker link")
sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup()
sConn = sess.Connect(ctx, t)
sConn.TestSendRecvAll(t)
}
}
func TestWorkerSessionCleanupMultiController(t *testing.T) {
require := require.New(t)
logger := hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
// ******************
// ** Controller 1 **
// ******************
conf1, err := config.DevController()
require.NoError(err)
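// testWorkerSessionCleanupMulti returns the multi-controller session
// cleanup subtest for the given burden case.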
func testWorkerSessionCleanupMulti(burdenCase timeoutBurdenType) func(t *testing.T) {
return func(t *testing.T) {
t.Parallel()
require := require.New(t)
logger := hclog.New(&hclog.LoggerOptions{
Name: t.Name(),
Level: hclog.Trace,
})
// ******************
// ** Controller 1 **
// ******************
conf1, err := config.DevController()
require.NoError(err)
pl1, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf1,
InitialResourcesSuffix: "1234567890",
Logger: logger.Named("c1"),
PublicClusterAddr: pl1.Addr().String(),
StatusGracePeriodDuration: controllerGracePeriod(burdenCase),
})
defer c1.Shutdown()
// ******************
// ** Controller 2 **
// ******************
pl2, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{
Logger: logger.Named("c2"),
PublicClusterAddr: pl2.Addr().String(),
StatusGracePeriodDuration: controllerGracePeriod(burdenCase),
})
defer c2.Shutdown()
expectWorkers(t, c1)
expectWorkers(t, c2)
// *************
// ** Proxy 1 **
// *************
require.Len(c1.ClusterAddrs(), 1)
p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0],
dawdle.WithListener(pl1),
dawdle.WithRbufSize(256),
dawdle.WithWbufSize(256),
)
require.NoError(err)
defer p1.Close()
require.NotEmpty(t, p1.ListenerAddr())
// *************
// ** Proxy 2 **
// *************
require.Len(c2.ClusterAddrs(), 1)
p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0],
dawdle.WithListener(pl2),
dawdle.WithRbufSize(256),
dawdle.WithWbufSize(256),
)
require.NoError(err)
defer p2.Close()
require.NotEmpty(t, p2.ListenerAddr())
// ************
// ** Worker **
// ************
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()},
Logger: logger.Named("w1"),
StatusGracePeriodDuration: workerGracePeriod(burdenCase),
})
defer w1.Shutdown()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
err = c1.WaitForNextWorkerStatusUpdate(w1.Name())
require.NoError(err)
err = c2.WaitForNextWorkerStatusUpdate(w1.Name())
require.NoError(err)
expectWorkers(t, c1, w1)
expectWorkers(t, c2, w1)
// Use an independent context for test things that take a context so
// that we aren't tied to any timeouts in the controller, etc. This
// can interfere with some of the test operations.
ctx := context.Background()
// Connect target
client := c1.Client()
client.SetToken(c1.Token().Token)
tcl := targets.NewClient(client)
tgt, err := tcl.Read(ctx, "ttcp_1234567890")
require.NoError(err)
require.NotNil(tgt)
// Create test server, update default port on target
ts := newTestTcpServer(t, logger)
require.NotNil(ts)
defer ts.Close()
tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1))
require.NoError(err)
require.NotNil(tgt)
// Authorize and connect
sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890")
sConn := sess.Connect(ctx, t)
// Run initial send/receive test, make sure things are working
logger.Debug("running initial send/recv test")
sConn.TestSendRecvAll(t)
// Kill connection to first controller, and run test again, should
// pass, deferring to other controller. Wait for the next
// successful status report to ensure this.
logger.Debug("pausing link to controller #1")
p1.Pause()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
sConn.TestSendRecvAll(t)
// Resume first controller, pause second. This one should work too.
logger.Debug("pausing link to controller #2, resuming #1")
p1.Resume()
p2.Pause()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
sConn.TestSendRecvAll(t)
// Kill the first controller connection again. This one should fail
// due to lack of any connection.
logger.Debug("pausing link to controller #1 again, both connections should be offline")
p1.Pause()
// Wait for failure connection state (depends on burden case)
switch burdenCase {
case timeoutBurdenTypeWorker:
// Wait on worker, then check controller
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected)
case timeoutBurdenTypeController:
// Wait on controller, then check worker
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected)
default:
// Should be closed on both worker and controller. Wait on
// worker then check controller.
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
}
pl1, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf1,
InitialResourcesSuffix: "1234567890",
Logger: logger.Named("c1"),
PublicClusterAddr: pl1.Addr().String(),
})
defer c1.Shutdown()
// Run send/receive test again to check expected connection-level
// behavior
if burdenCase == timeoutBurdenTypeController {
// Burden on controller, should be successful until connection
// resumes
sConn.TestSendRecvAll(t)
} else {
// Connection should die in other cases
sConn.TestSendRecvFail(t)
}
// ******************
// ** Controller 2 **
// ******************
pl2, err := net.Listen("tcp", "localhost:0")
require.NoError(err)
c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{
Logger: c1.Config().Logger.ResetNamed("c2"),
PublicClusterAddr: pl2.Addr().String(),
})
defer c2.Shutdown()
expectWorkers(t, c1)
expectWorkers(t, c2)
// *************
// ** Proxy 1 **
// *************
require.Len(c1.ClusterAddrs(), 1)
p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0],
dawdle.WithListener(pl1),
dawdle.WithRbufSize(512),
dawdle.WithWbufSize(512),
)
require.NoError(err)
defer p1.Close()
require.NotEmpty(t, p1.ListenerAddr())
// *************
// ** Proxy 2 **
// *************
require.Len(c2.ClusterAddrs(), 1)
p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0],
dawdle.WithListener(pl2),
dawdle.WithRbufSize(512),
dawdle.WithWbufSize(512),
)
require.NoError(err)
defer p2.Close()
require.NotEmpty(t, p2.ListenerAddr())
// ************
// ** Worker **
// ************
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()},
Logger: logger.Named("w1"),
})
defer w1.Shutdown()
time.Sleep(10 * time.Second)
expectWorkers(t, c1, w1)
expectWorkers(t, c2, w1)
// Use an independent context for test things that take a context so
// that we aren't tied to any timeouts in the controller, etc. This
// can interfere with some of the test operations.
ctx := context.Background()
// Connect target
client := c1.Client()
client.SetToken(c1.Token().Token)
tcl := targets.NewClient(client)
tgt, err := tcl.Read(ctx, "ttcp_1234567890")
require.NoError(err)
require.NotNil(tgt)
// Finally resume both, try again. Should behave as per normal.
logger.Debug("resuming connections to both controllers")
p1.Resume()
p2.Resume()
err = w1.Worker().WaitForNextSuccessfulStatusUpdate()
require.NoError(err)
// Do something post-reconnect depending on burden case. Note in
// the default case, both worker and controller should be
// relatively in sync, so we don't worry about these
// post-reconnection assertions.
switch burdenCase {
case timeoutBurdenTypeWorker:
// If we are expecting the worker to be the source of truth of
// a connection status, ensure that our old session's
// connections are actually closed now that the worker is
// properly reporting in again.
sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed)
case timeoutBurdenTypeController:
// If we are expecting the controller to be the source of
// truth, the connection should now be forcibly closed after
// the worker gets a status change request back.
sConn.TestSendRecvFail(t)
}
// Create test server, update default port on target
ts := newTestTcpServer(t, logger)
require.NotNil(ts)
defer ts.Close()
tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()))
require.NoError(err)
require.NotNil(tgt)
// Authorize and connect
sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890")
sConn := sess.Connect(ctx, t, logger)
// Run initial send/receive test, make sure things are working
sConn.TestSendRecvAll(t)
// Kill connection to first controller, and run test again, should
// pass, deferring to other controller.
p1.Pause()
sConn.TestSendRecvAll(t)
// Resume first controller, pause second. This one should work too.
p1.Resume()
p2.Pause()
sConn.TestSendRecvAll(t)
// Kill the first controller connection again. This one should fail
// due to lack of any connection.
p1.Pause()
sConn.TestSendRecvFail(t)
// Assert we have no connections left (should be default behavior)
sess.TestNoConnectionsLeft(t)
// Assert connection has been removed from the local worker state
sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed)
// Finally resume both, try again. Should behave as per normal.
p1.Resume()
p2.Resume()
time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy
sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup()
sConn = sess.Connect(ctx, t, logger)
sConn.TestSendRecvAll(t)
// Proceed with new connection test
logger.Debug("connecting to new session after resuming controller/worker link")
sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup()
sConn = sess.Connect(ctx, t)
sConn.TestSendRecvAll(t)
}
}
// testSession represents an authorized session.
@ -392,10 +579,73 @@ func (s *testSession) connect(ctx context.Context, t *testing.T) net.Conn {
return websocket.NetConn(ctx, conn, websocket.MessageBinary)
}
// TestNoConnectionsLeft asserts that there are no connections left.
func (s *testSession) TestNoConnectionsLeft(t *testing.T) {
// ExpectConnectionStateOnController waits until all connections in a
// session have transitioned to a particular state on the controller.
func (s *testSession) ExpectConnectionStateOnController(
ctx context.Context,
t *testing.T,
tc *controller.TestController,
expectState session.ConnectionStatus,
) {
t.Helper()
require.Zero(t, s.connectionsLeft)
require := require.New(t)
assert := assert.New(t)
ctx, cancel := context.WithTimeout(ctx, expectConnectionStateOnControllerTimeout)
defer cancel()
// This is just for initialization of the actual state set.
const sessionStatusUnknown session.ConnectionStatus = "unknown"
// Get all connections for the session on the controller.
sessionRepo, err := tc.Controller().SessionRepoFn()
require.NoError(err)
conns, err := sessionRepo.ListConnectionsBySessionId(ctx, s.sessionId)
require.NoError(err)
// To avoid a misleadingly passing test, we require that the
// session under test has at least one connection.
require.Greater(len(conns), 0, "should have at least one connection")
// Make a set of states, 1 per connection
actualStates := make([]session.ConnectionStatus, len(conns))
for i := range actualStates {
actualStates[i] = sessionStatusUnknown
}
// Make expect set for comparison
expectStates := make([]session.ConnectionStatus, len(conns))
for i := range expectStates {
expectStates[i] = expectState
}
for {
if ctx.Err() != nil {
break
}
for i, conn := range conns {
_, states, err := sessionRepo.LookupConnection(ctx, conn.PublicId, nil)
require.NoError(err)
// Look at the first state in the returned list, which will
// be the most recent state.
actualStates[i] = states[0].Status
}
if reflect.DeepEqual(expectStates, actualStates) {
break
}
time.Sleep(time.Second)
}
// "non-fatal" assert here, so that we can surface both timeouts
// and invalid state
assert.NoError(ctx.Err())
// Assert
require.Equal(expectStates, actualStates)
s.logger.Debug("successfully asserted all connection states on controller", "expected_states", expectStates, "actual_states", actualStates)
}
// ExpectConnectionStateOnWorker waits until all connections in a
@ -491,7 +741,6 @@ type testSessionConnection struct {
func (s *testSession) Connect(
ctx context.Context,
t *testing.T, // Just to add cleanup
logger hclog.Logger,
) *testSessionConnection {
t.Helper()
require := require.New(t)
@ -504,7 +753,7 @@ func (s *testSession) Connect(
return &testSessionConnection{
conn: conn,
logger: logger,
logger: s.logger,
}
}
@ -518,6 +767,12 @@ func (s *testSession) Connect(
func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
t.Helper()
require := require.New(t)
// This is a fairly arbitrary value; the send/recv itself is
// effectively instantaneous. The main point is to repeat it enough
// times to know the connection is stable.
const testSendRecvSendMax = 100
for i := uint32(0); i < testSendRecvSendMax; i++ {
// Shuttle over the sequence number as base64.
err := binary.Write(c.conn, binary.LittleEndian, i)
@ -536,7 +791,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
var j uint32
err = binary.Read(c.conn, binary.LittleEndian, &j)
if err != nil {
c.logger.Debug("received error during read", "err", err)
c.logger.Debug("received error during read", "err", err, "num_successfully_sent", i)
if errors.Is(err, net.ErrClosed) ||
errors.Is(err, io.EOF) ||
errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) {
@ -549,9 +804,10 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
require.Equal(j, i)
// Sleep 1s
time.Sleep(time.Second)
// time.Sleep(time.Second)
}
c.logger.Debug("finished send/recv successfully", "num_successfully_sent", testSendRecvSendMax)
return true
}
@ -560,6 +816,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
func (c *testSessionConnection) TestSendRecvAll(t *testing.T) {
t.Helper()
require.True(t, c.testSendRecv(t))
c.logger.Debug("successfully asserted send/recv as passing")
}
// TestSendRecvFail asserts that we were able to send/recv all pings
@ -567,6 +824,7 @@ func (c *testSessionConnection) TestSendRecvAll(t *testing.T) {
func (c *testSessionConnection) TestSendRecvFail(t *testing.T) {
t.Helper()
require.False(t, c.testSendRecv(t))
c.logger.Debug("successfully asserted send/recv as failing")
}
type testTcpServer struct {
