From 5a70875726ddeecfa3a7a9b876f27a3ce30701ed Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Wed, 14 Jul 2021 09:14:24 -0700 Subject: [PATCH] internal/servers/controller: Worker failure connection cleanup (#1340) This commit adds the support to do the following: * Mark connections for non-reporting workers as closed. This is the controller counterpart to the worker functionality (see #1330). This is written as a scheduled job that does the work DB-side in a single atomic query. * Works to reconcile states if such a broken controller-worker connection resumes and a worker reports a connection as connected that should be disconnected. In this case, the controller will send an update request, and the worker will honor it and terminate the connection. * Further refinement of the grace period setting has been added here. We have converged on the current server "liveness" setting as our default here, which is half of the previous 30s (15 seconds, in other words). Additionally, this is now configurable on the controller and worker side, with the caveat that it's currently impossible to do so in config as the setting has been untagged in HCL. This is exposed so that we can run some sophisticated testing scenarios where we skew the grace period to either the controller or worker to ensure the aforementioned reconciliation works. * Some repository functions have been added to support the new functionality, in addition to some test code to the worker to allow querying of session state while testing. * Finally, we've added some add timestamp subtraction functions as well, basically serving as the opposite of the addition functions. --- internal/cmd/base/servers.go | 72 ++ internal/cmd/commands/dev/dev.go | 4 + internal/cmd/commands/server/server.go | 4 + internal/cmd/config/config.go | 14 + .../postgres/12/01_timestamp_sub_funcs.up.sql | 23 + .../12/01_timestamp_sub_funcs_test.go | 61 ++ internal/db/schema/postgres_migration.gen.go | 23 +- internal/servers/controller/controller.go | 30 +- .../handlers/workers/worker_service.go | 70 +- .../servers/controller/session_cleanup_job.go | 133 ++++ .../controller/session_cleanup_job_test.go | 170 +++++ internal/servers/controller/testing.go | 77 ++ internal/servers/options.go | 5 +- internal/servers/repository.go | 23 +- internal/servers/repository_test.go | 79 ++ internal/servers/worker/status.go | 147 ++-- internal/servers/worker/status_test.go | 56 ++ internal/servers/worker/testing.go | 16 +- internal/servers/worker/worker.go | 2 +- internal/session/query.go | 101 ++- internal/session/repository_connection.go | 144 +++- .../session/repository_connection_test.go | 445 ++++++++++- .../tests/cluster/session_cleanup_test.go | 688 ++++++++++++------ 23 files changed, 2081 insertions(+), 306 deletions(-) create mode 100644 internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql create mode 100644 internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go create mode 100644 internal/servers/controller/session_cleanup_job.go create mode 100644 internal/servers/controller/session_cleanup_job_test.go create mode 100644 internal/servers/worker/status_test.go diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 3e376c0ee9..0560d8a57f 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -15,6 +15,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/armon/go-metrics" berrors "github.com/hashicorp/boundary/internal/errors" @@ -24,6 +25,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/observability/event" + "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/types/scope" "github.com/hashicorp/boundary/version" "github.com/hashicorp/go-hclog" @@ -40,6 +42,23 @@ import ( "google.golang.org/grpc/grpclog" ) +const ( + // defaultStatusGracePeriod is the default status grace period, or the period + // of time that we will go without a status report before we start + // disconnecting and marking connections as closed. This is tied to the + // server default liveness setting, a related value. See the servers package + // for more details. + defaultStatusGracePeriod = servers.DefaultLiveness + + // statusGracePeriodEnvVar is the environment variable that can be used to + // configure the status grace period. This setting is provided in seconds, + // and can never be lower than the default status grace period defined above. + // + // TODO: This value is temporary, it will be removed once we have a better + // story/direction on attributes and system defaults. + statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD" +) + type Server struct { *Command @@ -101,6 +120,11 @@ type Server struct { DevDatabaseCleanupFunc func() error Database *gorm.DB + + // StatusGracePeriodDuration represents the period of time (as a + // duration) that the controller will wait before marking + // connections from a disconnected worker as invalid. + StatusGracePeriodDuration time.Duration } func NewServer(cmd *Command) *Server { @@ -675,3 +699,51 @@ func MakeSighupCh() chan struct{} { }() return resultCh } + +// SetStatusGracePeriodDuration sets the value for +// StatusGracePeriodDuration. +// +// The grace period is the length of time we allow connections to run +// on a worker in the event of an error sending status updates. The +// period is defined the length of time since the last successful +// update. +// +// The setting is derived from one of the following, in order: +// +// * Via the supplied value if non-zero. +// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an +// integer value to define the setting. +// * If either of these is missing, the default is used. See the +// defaultStatusGracePeriod value for the default value. +// +// The minimum setting for this value is the default setting. Values +// below this will be reset to the default. +func (s *Server) SetStatusGracePeriodDuration(value time.Duration) { + var result time.Duration + switch { + case value > 0: + result = value + case os.Getenv(statusGracePeriodEnvVar) != "": + // TODO: See the description of the constant for more details on + // this env var + v := os.Getenv(statusGracePeriodEnvVar) + n, err := strconv.Atoi(v) + if err != nil { + s.Logger.Error(fmt.Sprintf("could not read setting for %s", statusGracePeriodEnvVar), + "err", err, + "value", v, + ) + break + } + + result = time.Second * time.Duration(n) + } + + if result < defaultStatusGracePeriod { + s.Logger.Debug("invalid grace period setting or none provided, using default", "value", result, "default", defaultStatusGracePeriod) + result = defaultStatusGracePeriod + } + + s.Logger.Debug("session cleanup in effect, connections will be terminated if status reports cannot be made", "grace_period", result) + s.StatusGracePeriodDuration = result +} diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index eaecebe511..742e28b7ee 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -425,6 +425,10 @@ func (c *Command) Run(args []string) int { return base.CommandCliError } + // Initialize status grace period (0 denotes using env or default + // here) + c.SetStatusGracePeriodDuration(0) + base.StartMemProfiler(c.Logger) if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil { diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index f663a609fc..ef4a498f68 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -156,6 +156,10 @@ func (c *Command) Run(args []string) int { return base.CommandUserError } + // Initialize status grace period (0 denotes using env or default + // here) + c.SetStatusGracePeriodDuration(0) + base.StartMemProfiler(c.Logger) if !c.skipMetrics { diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index b51bbb7558..551cdb29d4 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -119,6 +119,13 @@ type Controller struct { // denoted by time.Duration AuthTokenTimeToStale interface{} `hcl:"auth_token_time_to_stale"` AuthTokenTimeToStaleDuration time.Duration + + // StatusGracePeriod represents the period of time (as a duration) that the + // controller will wait before marking connections from a disconnected worker + // as invalid. + // + // TODO: This field is currently internal. + StatusGracePeriodDuration time.Duration `hcl:"-"` } type Worker struct { @@ -132,6 +139,13 @@ type Worker struct { // key=value syntax. This is trued up in the Parse function below. TagsRaw interface{} `hcl:"tags"` Tags map[string][]string `hcl:"-"` + + // StatusGracePeriod represents the period of time (as a duration) that the + // worker will wait before disconnecting connections if it cannot make a + // status report to a controller. + // + // TODO: This field is currently internal. + StatusGracePeriodDuration time.Duration `hcl:"-"` } type Database struct { diff --git a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql new file mode 100644 index 0000000000..09b2468802 --- /dev/null +++ b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql @@ -0,0 +1,23 @@ +begin; + + create function wt_sub_seconds(sec integer, ts timestamp with time zone) + returns timestamp with time zone + as $$ + select ts - sec * '1 second'::interval; + $$ language sql + stable + returns null on null input; + comment on function wt_add_seconds is + 'wt_sub_seconds returns ts - sec.'; + + create function wt_sub_seconds_from_now(sec integer) + returns timestamp with time zone + as $$ + select wt_sub_seconds(sec, current_timestamp); + $$ language sql + stable + returns null on null input; + comment on function wt_add_seconds_to_now is + 'wt_sub_seconds_from_now returns current_timestamp - sec.'; + +commit; diff --git a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go new file mode 100644 index 0000000000..3eda4f90f0 --- /dev/null +++ b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go @@ -0,0 +1,61 @@ +package migration + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/db/schema" + "github.com/hashicorp/boundary/internal/docker" + "github.com/stretchr/testify/require" +) + +const targetMigration = 12001 + +func TestWtSubSeconds(t *testing.T) { + t.Parallel() + require := require.New(t) + ctx := context.Background() + d := testSetupDb(ctx, t) + + // Test by subtracing a day from the test date + sourceTime, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05+07:00") + require.NoError(err) + expectedTime := sourceTime.Add(time.Second * -86400) + + var actualTime time.Time + row := d.QueryRowContext(ctx, "select wt_sub_seconds($1, $2)", 86400, sourceTime) + require.NoError(row.Scan(&actualTime)) + require.True(expectedTime.Equal(actualTime)) +} + +func testSetupDb(ctx context.Context, t *testing.T) *sql.DB { + t.Helper() + require := require.New(t) + + dialect := "postgres" + + c, u, _, err := docker.StartDbInDocker(dialect) + require.NoError(err) + t.Cleanup(func() { + require.NoError(c()) + }) + d, err := sql.Open(dialect, u) + require.NoError(err) + + oState := schema.TestCloneMigrationStates(t) + nState := schema.TestCreatePartialMigrationState(oState["postgres"], targetMigration) + oState["postgres"] = nState + + m, err := schema.NewManager(ctx, dialect, d, schema.WithMigrationStates(oState)) + require.NoError(err) + + require.NoError(m.RollForward(ctx)) + state, err := m.CurrentState(ctx) + require.NoError(err) + require.Equal(targetMigration, state.DatabaseSchemaVersion) + require.False(state.Dirty) + + return d +} diff --git a/internal/db/schema/postgres_migration.gen.go b/internal/db/schema/postgres_migration.gen.go index d31e3d11ae..8b501a12c5 100644 --- a/internal/db/schema/postgres_migration.gen.go +++ b/internal/db/schema/postgres_migration.gen.go @@ -4,7 +4,7 @@ package schema func init() { migrationStates["postgres"] = migrationState{ - binarySchemaVersion: 11001, + binarySchemaVersion: 12001, upMigrations: map[int][]byte{ 1: []byte(` create domain wt_public_id as text @@ -6100,6 +6100,27 @@ alter table server foreign key (type) references server_type_enm(name) on update cascade on delete restrict; +`), + 12001: []byte(` +create function wt_sub_seconds(sec integer, ts timestamp with time zone) + returns timestamp with time zone + as $$ + select ts - sec * '1 second'::interval; + $$ language sql + stable + returns null on null input; + comment on function wt_add_seconds is + 'wt_sub_seconds returns ts - sec.'; + + create function wt_sub_seconds_from_now(sec integer) + returns timestamp with time zone + as $$ + select wt_sub_seconds(sec, current_timestamp); + $$ language sql + stable + returns null on null input; + comment on function wt_add_seconds_to_now is + 'wt_sub_seconds_from_now returns current_timestamp - sec.'; `), 2001: []byte(` -- log_migration entries represent logs generated during migrations diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index f4366ba004..a9029547b4 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -38,7 +38,7 @@ type Controller struct { workerAuthCache *cache.Cache - // Used for testing + // Used for testing and tracking worker health workerStatusUpdateTimes *sync.Map // Repo factory methods @@ -117,7 +117,9 @@ func New(conf *Config) (*Controller, error) { jobRepoFn := func() (*job.Repository, error) { return job.NewRepository(dbase, dbase, c.kms) } - c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger) + // TODO: the RunJobsLimit is temporary until a better fix gets in. This + // currently caps the scheduler at running 10 jobs per interval. + c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger, scheduler.WithRunJobsLimit(10)) if err != nil { return nil, fmt.Errorf("error creating new scheduler: %w", err) } @@ -182,7 +184,29 @@ func (c *Controller) Start() error { func (c *Controller) registerJobs() error { rw := db.New(c.conf.Database) - return vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger) + if err := vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger); err != nil { + return err + } + + if err := c.registerSessionCleanupJob(); err != nil { + return err + } + + return nil +} + +// registerSessionCleanupJob is a helper method to abstract +// registering the session cleanup job specifically. +func (c *Controller) registerSessionCleanupJob() error { + sessionCleanupJob, err := newSessionCleanupJob(c.logger, c.SessionRepoFn, int(c.conf.StatusGracePeriodDuration.Seconds())) + if err != nil { + return fmt.Errorf("error creating session cleanup job: %w", err) + } + if err = c.scheduler.RegisterJob(c.baseContext, sessionCleanupJob); err != nil { + return fmt.Errorf("error registering session cleanup job: %w", err) + } + + return nil } func (c *Controller) Shutdown(serversOnly bool) error { diff --git a/internal/servers/controller/handlers/workers/worker_service.go b/internal/servers/controller/handlers/workers/worker_service.go index 4135ccf7b0..56520a03f8 100644 --- a/internal/servers/controller/handlers/workers/worker_service.go +++ b/internal/servers/controller/handlers/workers/worker_service.go @@ -76,12 +76,19 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques Controllers: controllers, } - // TODO (jeff, 05-2021): We should possibly list all connections here with - // their statuses, and check those against the found connections below. If - // any we think should be closed are marked open by the worker, the worker - // should be told to close them. - - var foundConns []string + var ( + // For tracking the reported open connections. + reportedOpenConns []string + // For tracking the session IDs we've already requested + // cancellation for. We won't need to add connection cancel + // requests for these because canceling the session terminates the + // connections. + requestedSessionCancelIds []string + ) + + // This is a map of all sessions and their statuses. We keep track of + // this for easy lookup if we need to make change requests. + sessionStatuses := make(map[string]pbs.SESSIONSTATUS) for _, jobStatus := range req.GetJobs() { switch jobStatus.Job.GetType() { @@ -92,6 +99,9 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques return nil, status.Error(codes.Internal, "Error getting session info at status time") } + // Record status. + sessionStatuses[si.GetSessionId()] = si.Status + // Check connections before potentially bypassing the rest of the // logic in the switch on si.Status. sessConns := si.GetConnections() @@ -103,7 +113,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques // report as found, so that we should attempt to close it. // Note that unspecified is the default state for the enum // but it's not ever explicitly set by us. - foundConns = append(foundConns, conn.GetConnectionId()) + reportedOpenConns = append(reportedOpenConns, conn.GetConnectionId()) } } @@ -127,7 +137,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques } // If the session from the DB is in canceling status, and we're // here, it means the job is in pending or active; cancel it. If - // it's in termianted status something went wrong and we're + // it's in terminated status something went wrong and we're // mismatched, so ensure we cancel it also. currState := sessionInfo.States[0].Status if currState.ProtoVal() != si.Status { @@ -148,13 +158,53 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques }, RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE, }) + // Log the session ID so we don't add a duplicate change + // request on connection normalization. + requestedSessionCancelIds = append(requestedSessionCancelIds, sessionId) } } } } - // Run our connection cleanup function - closedConns, err := sessRepo.CloseDeadConnectionsOnWorkerReport(ctx, req.Worker.PrivateId, foundConns) + // Normalize the current state of connections on the worker side + // with the data from the controller. In other words, if one of our + // found connections isn't supposed to be alive still, kill it. + // + // This is separate from the above session normalization and is + // additive to it, we don't add sessions that have already been + // added there as canceling sessions already closes the + // connections. + shouldCloseConnections, err := sessRepo.ShouldCloseConnectionsOnWorker(ctx, reportedOpenConns, requestedSessionCancelIds) + if err != nil { + return nil, status.Errorf(codes.Internal, "Error fetching connections that should be closed: %v", err) + } + + for sessionId, connIds := range shouldCloseConnections { + var connChanges []*pbs.Connection + for _, connId := range connIds { + connChanges = append(connChanges, &pbs.Connection{ + ConnectionId: connId, + Status: session.StatusClosed.ProtoVal(), + }) + } + + ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{ + Job: &pbs.Job{ + Type: pbs.JOBTYPE_JOBTYPE_SESSION, + JobInfo: &pbs.Job_SessionInfo{ + SessionInfo: &pbs.SessionJobInfo{ + SessionId: sessionId, + Status: sessionStatuses[sessionId], + Connections: connChanges, + }, + }, + }, + RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE, + }) + } + + // Run our controller-side cleanup function. + closedConns, err := sessRepo.CloseDeadConnectionsForWorker(ctx, req.Worker.PrivateId, reportedOpenConns) if err != nil { return nil, status.Errorf(codes.Internal, "Error closing dead conns for worker %s: %v", req.Worker.PrivateId, err) } diff --git a/internal/servers/controller/session_cleanup_job.go b/internal/servers/controller/session_cleanup_job.go new file mode 100644 index 0000000000..fd59e6fa8d --- /dev/null +++ b/internal/servers/controller/session_cleanup_job.go @@ -0,0 +1,133 @@ +package controller + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/scheduler" + "github.com/hashicorp/boundary/internal/servers/controller/common" + "github.com/hashicorp/boundary/internal/session" + "github.com/hashicorp/go-hclog" +) + +// sessionCleanupJob defines a periodic job that monitors workers for +// loss of connection and terminates connections on workers that have +// not sent a heartbeat in a significant period of time. +// +// Relevant connections are simply marked as disconnected in the +// database. Connections will be independently terminated by the +// worker, or the event of a synchronization issue between the two, +// the controller will win out and order that the connections be +// closed on the worker. +type sessionCleanupJob struct { + logger hclog.Logger + sessionRepoFn common.SessionRepoFactory + + // The amount of time to give disconnected workers before marking + // their connections as closed. + gracePeriod int + + // The total number of connections closed in the last run. + totalClosed int +} + +// newSessionCleanupJob instantiates the session cleanup job. +func newSessionCleanupJob( + logger hclog.Logger, + sessionRepoFn common.SessionRepoFactory, + gracePeriod int, +) (*sessionCleanupJob, error) { + const op = "controller.newNewSessionCleanupJob" + switch { + case logger == nil: + return nil, errors.New(errors.InvalidParameter, op, "missing logger") + case sessionRepoFn == nil: + return nil, errors.New(errors.InvalidParameter, op, "missing sessionRepoFn") + case gracePeriod < session.DeadWorkerConnCloseMinGrace: + return nil, errors.New( + errors.InvalidParameter, op, fmt.Sprintf("invalid gracePeriod, must be greater than %d", session.DeadWorkerConnCloseMinGrace)) + } + + return &sessionCleanupJob{ + logger: logger, + sessionRepoFn: sessionRepoFn, + gracePeriod: gracePeriod, + }, nil +} + +// Name returns a short, unique name for the job. +func (j *sessionCleanupJob) Name() string { return "session_cleanup" } + +// Description returns the description for the job. +func (j *sessionCleanupJob) Description() string { + return "Clean up session connections from disconnected workers" +} + +// NextRunIn returns the next run time after a job is completed. +// +// The next run time is defined for sessionCleanupJob as one second. +// This is because the job should run continuously to terminate +// connections as soon as a worker has not reported in for a long +// enough time. Only one job will ever run at once, so there is no +// reason why it cannot run again immediately. +func (j *sessionCleanupJob) NextRunIn() (time.Duration, error) { return time.Second, nil } + +// Status returns the status of the running job. +func (j *sessionCleanupJob) Status() scheduler.JobStatus { + return scheduler.JobStatus{ + Completed: j.totalClosed, + Total: j.totalClosed, + } +} + +// Run executes the job. +func (j *sessionCleanupJob) Run(ctx context.Context) error { + const op = "controller.(sessionCleanupJob).Run" + j.logger.Debug( + "starting job", + "op", op, + ) + j.totalClosed = 0 + + // Load repos. + sessionRepo, err := j.sessionRepoFn() + if err != nil { + return errors.Wrap(err, op, errors.WithMsg("error getting session repo")) + } + + // Run the atomic dead worker cleanup job. + results, err := sessionRepo.CloseConnectionsForDeadWorkers(ctx, j.gracePeriod) + if err != nil { + return errors.Wrap(err, op) + } + + if len(results) < 1 { + j.logger.Debug( + "all workers OK, no connections to close", + "op", op, + ) + } else { + for _, result := range results { + // Log a closed connection message for each worker as a warning + j.logger.Warn( + "worker has not reported status within acceptable grace period, all connections closed", + "op", op, + "private_id", result.ServerId, + "update_time", result.LastUpdateTime, + "grace_period_seconds", j.gracePeriod, + "number_connections_closed", result.NumberConnectionsClosed, + ) + + j.totalClosed += result.NumberConnectionsClosed + } + } + + j.logger.Debug( + "job finished", + "op", op, + "total_connections_closed", j.totalClosed, + ) + return nil +} diff --git a/internal/servers/controller/session_cleanup_job_test.go b/internal/servers/controller/session_cleanup_job_test.go new file mode 100644 index 0000000000..b83dd67534 --- /dev/null +++ b/internal/servers/controller/session_cleanup_job_test.go @@ -0,0 +1,170 @@ +package controller + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/scheduler" + "github.com/hashicorp/boundary/internal/servers" + "github.com/hashicorp/boundary/internal/session" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// assert the interface +var _ = scheduler.Job(new(sessionCleanupJob)) + +// This test has been largely adapted from +// TestRepository_CloseDeadConnectionsOnWorker in +// internal/session/repository_connection_test.go. +func TestSessionCleanupJob(t *testing.T) { + t.Parallel() + require, assert := require.New(t), assert.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + sessionRepo, err := session.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + numConns := 12 + + // Create two "workers". One will remain untouched while the other "goes + // away and comes back" (worker 2). + worker1 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker1")) + worker2 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker2")) + + // Create a few sessions on each, activate, and authorize a connection + var connIds []string + connIdsByWorker := make(map[string][]string) + for i := 0; i < numConns; i++ { + serverId := worker1.PrivateId + if i%2 == 0 { + serverId = worker2.PrivateId + } + sess := session.TestDefaultSession(t, conn, wrapper, iamRepo, session.WithServerId(serverId), session.WithDbOpts(db.WithSkipVetForWrite(true))) + sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) + require.NoError(err) + c, cs, _, err := sessionRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + require.NoError(err) + require.Len(cs, 1) + require.Equal(session.StatusAuthorized, cs[0].Status) + connIds = append(connIds, c.GetPublicId()) + if i%2 == 0 { + connIdsByWorker[worker2.PrivateId] = append(connIdsByWorker[worker2.PrivateId], c.GetPublicId()) + } else { + connIdsByWorker[worker1.PrivateId] = append(connIdsByWorker[worker1.PrivateId], c.GetPublicId()) + } + } + + // Mark half of the connections connected and leave the others authorized. + // This is just to ensure we have a spread when we test it out. + for i, connId := range connIds { + if i%2 == 0 { + _, cs, err := sessionRepo.ConnectConnection(ctx, session.ConnectWith{ + ConnectionId: connId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 22, + }) + require.NoError(err) + require.Len(cs, 2) + var foundAuthorized, foundConnected bool + for _, status := range cs { + if status.Status == session.StatusAuthorized { + foundAuthorized = true + } + if status.Status == session.StatusConnected { + foundConnected = true + } + } + require.True(foundAuthorized) + require.True(foundConnected) + } + } + + // Create the job. + job, err := newSessionCleanupJob( + hclog.New(&hclog.LoggerOptions{Level: hclog.Trace}), + func() (*session.Repository, error) { return sessionRepo, nil }, + session.DeadWorkerConnCloseMinGrace, + ) + require.NoError(err) + + // sleep the status grace period. + time.Sleep(time.Second * time.Duration(session.DeadWorkerConnCloseMinGrace)) + + // Push an upsert to the first worker so that its status has been + // updated. + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, worker1, []servers.Option{}...) + require.NoError(err) + require.Equal(1, rowsUpdated) + + // Run the job. + require.NoError(job.Run(ctx)) + + // Assert connection state on both workers. + assertConnections := func(workerId string, closed bool) { + connIds, ok := connIdsByWorker[workerId] + require.True(ok) + require.Len(connIds, 6) + for _, connId := range connIds { + _, states, err := sessionRepo.LookupConnection(ctx, connId, nil) + require.NoError(err) + var foundClosed bool + for _, state := range states { + if state.Status == session.StatusClosed { + foundClosed = true + break + } + } + assert.Equal(closed, foundClosed) + } + } + + // Assert that all connections on the second worker are closed + assertConnections(worker2.PrivateId, true) + // Assert that all connections on the first worker are still open + assertConnections(worker1.PrivateId, false) +} + +func TestSessionCleanupJobNewJobErr(t *testing.T) { + t.Parallel() + const op = "controller.newNewSessionCleanupJob" + require := require.New(t) + + job, err := newSessionCleanupJob(nil, nil, 0) + require.Equal(err, errors.E( + errors.WithCode(errors.InvalidParameter), + errors.WithOp(op), + errors.WithMsg("missing logger"), + )) + require.Nil(job) + + job, err = newSessionCleanupJob(hclog.New(nil), nil, 0) + require.Equal(err, errors.E( + errors.WithCode(errors.InvalidParameter), + errors.WithOp(op), + errors.WithMsg("missing sessionRepoFn"), + )) + require.Nil(job) + + job, err = newSessionCleanupJob(hclog.New(nil), func() (*session.Repository, error) { return nil, nil }, 0) + require.Equal(err, errors.E( + errors.WithCode(errors.InvalidParameter), + errors.WithOp(op), + errors.WithMsg(fmt.Sprintf("invalid gracePeriod, must be greater than %d", session.DeadWorkerConnCloseMinGrace)), + )) + require.Nil(job) +} diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 620dccfae5..67be160392 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/authmethods" @@ -371,6 +372,10 @@ type TestControllerOpts struct { // A cluster address for overriding the advertised controller listener // (overrides address provided in config, if any) PublicClusterAddr string + + // The amount of time to wait before marking connections as closed when a + // worker has not reported in + StatusGracePeriodDuration time.Duration } func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { @@ -456,6 +461,9 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { t.Fatal(err) } + // Initialize status grace period + tc.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration) + if opts.Config.Controller == nil { opts.Config.Controller = new(config.Controller) } @@ -615,6 +623,7 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon DisableKmsKeyCreation: true, DisableAuthMethodCreation: true, PublicClusterAddr: opts.PublicClusterAddr, + StatusGracePeriodDuration: opts.StatusGracePeriodDuration, } if opts.Logger != nil { nextOpts.Logger = opts.Logger @@ -629,3 +638,71 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon } return NewTestController(t, nextOpts) } + +// WaitForNextWorkerStatusUpdate waits for the next status check from a worker to +// come in. If it does not come in within the default status grace +// period, this function returns an error. +func (tc *TestController) WaitForNextWorkerStatusUpdate(workerId string) error { + tc.Logger().Debug("waiting for next status report from worker", "worker", workerId) + + if err := tc.waitForNextWorkerStatusUpdate(workerId); err != nil { + tc.Logger().Error("error waiting for next status report from worker", "worker", workerId, "err", err) + return err + } + + tc.Logger().Debug("waiting for next status report from worker received successfully", "worker", workerId) + return nil +} + +func (tc *TestController) waitForNextWorkerStatusUpdate(workerId string) error { + waitStatusStart := time.Now() + ctx, cancel := context.WithTimeout(tc.ctx, tc.b.StatusGracePeriodDuration) + defer cancel() + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case <-time.After(time.Second): + // pass + } + + var waitStatusCurrent time.Time + var err error + tc.Controller().WorkerStatusUpdateTimes().Range(func(k, v interface{}) bool { + if k == nil || v == nil { + err = fmt.Errorf("nil key or value on entry: key=%#v value=%#v", k, v) + return false + } + + workerStatusUpdateId, ok := k.(string) + if !ok { + err = fmt.Errorf("unexpected type %T for key: key=%#v value=%#v", k, k, v) + return false + } + + workerStatusUpdateTime, ok := v.(time.Time) + if !ok { + err = fmt.Errorf("unexpected type %T for value: key=%#v value=%#v", k, k, v) + return false + } + + if workerStatusUpdateId == workerId { + waitStatusCurrent = workerStatusUpdateTime + return false + } + + return true + }) + + if err != nil { + return err + } + + if waitStatusCurrent.Sub(waitStatusStart) > 0 { + break + } + } + + return nil +} diff --git a/internal/servers/options.go b/internal/servers/options.go index 6c2b0cf510..f82f5a52d2 100644 --- a/internal/servers/options.go +++ b/internal/servers/options.go @@ -37,8 +37,9 @@ func WithLimit(limit int) Option { } } -// WithSkipVetForWrite provides an option to allow skipping vet checks to allow -// testing lower-level SQL triggers and constraints +// WithLiveness indicates how far back we want to search for server entries. +// Use 0 for the default liveness (15 seconds). A liveness value of -1 removes +// the liveliness condition. func WithLiveness(liveness time.Duration) Option { return func(o *options) { o.withLiveness = liveness diff --git a/internal/servers/repository.go b/internal/servers/repository.go index 53263c566d..c25a2a7592 100644 --- a/internal/servers/repository.go +++ b/internal/servers/repository.go @@ -13,7 +13,15 @@ import ( ) const ( - defaultLiveness = 15 * time.Second + // DefaultLiveness is the setting that controls the server "liveness" time, + // or the maximum allowable time that a worker can't send a status update to + // the controller for. After this, the server is considered dead, and it will + // be taken out of the rotation for allowable workers for connections, and + // connections will possibly start to be terminated and marked as closed + // depending on the grace period setting (see + // base.Server.StatusGracePeriodDuration). This value serves as the default + // and minimum allowable setting for the grace period. + DefaultLiveness = 15 * time.Second ) type ServerType string @@ -67,18 +75,27 @@ func (r *Repository) listServersWithReader(ctx context.Context, reader db.Reader opts := getOpts(opt...) liveness := opts.withLiveness if liveness == 0 { - liveness = defaultLiveness + liveness = DefaultLiveness } + + var where string + if liveness > 0 { + where = fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())) + } else { + where = "type = $1" + } + var servers []*Server if err := reader.SearchWhere( ctx, &servers, - fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())), + where, []interface{}{serverType}, db.WithLimit(-1), ); err != nil { return nil, errors.Wrap(err, "servers.listServersWithReader") } + return servers, nil } diff --git a/internal/servers/repository_test.go b/internal/servers/repository_test.go index f58736f0c3..b36f6d0ef3 100644 --- a/internal/servers/repository_test.go +++ b/internal/servers/repository_test.go @@ -1,12 +1,14 @@ package servers_test import ( + "context" "testing" "time" "github.com/hashicorp/boundary/api/roles" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/types/scope" @@ -180,3 +182,80 @@ func TestTagUpdatingListing(t *testing.T) { } require.Equal(exp, tags) } + +func TestListServersWithLiveness(t *testing.T) { + t.Parallel() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + + newServer := func(privateId string) *servers.Server { + result := &servers.Server{ + PrivateId: privateId, + Type: "worker", + Address: "127.0.0.1", + } + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, result) + require.NoError(err) + require.Equal(1, rowsUpdated) + + return result + } + + server1 := newServer("test1") + server2 := newServer("test2") + server3 := newServer("test3") + + // Sleep the default liveness time (15sec currently) +1s + time.Sleep(time.Second * 16) + + // Push an upsert to the first worker so that its status has been + // updated. + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, server1) + require.NoError(err) + require.Equal(1, rowsUpdated) + + requireIds := func(expected []string, actual []*servers.Server) { + require.Len(expected, len(actual)) + want := make(map[string]struct{}) + for _, v := range expected { + want[v] = struct{}{} + } + + got := make(map[string]struct{}) + for _, v := range actual { + got[v.PrivateId] = struct{}{} + } + + require.Equal(want, got) + } + + // Default liveness, should only list 1 + result, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker) + require.NoError(err) + require.Len(result, 1) + requireIds([]string{server1.PrivateId}, result) + + // Upsert second server. + _, rowsUpdated, err = serversRepo.UpsertServer(ctx, server2) + require.NoError(err) + require.Equal(1, rowsUpdated) + + // Static liveness. Should get two, so long as this did not take + // more than 5s to execute. + result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(time.Second*5)) + require.NoError(err) + require.Len(result, 2) + requireIds([]string{server1.PrivateId, server2.PrivateId}, result) + + // Liveness disabled, should get all three servers. + result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(-1)) + require.NoError(err) + require.Len(result, 3) + requireIds([]string{server1.PrivateId, server2.PrivateId, server3.PrivateId}, result) +} diff --git a/internal/servers/worker/status.go b/internal/servers/worker/status.go index 821ceac7ba..6a083b699c 100644 --- a/internal/servers/worker/status.go +++ b/internal/servers/worker/status.go @@ -3,8 +3,6 @@ package worker import ( "context" "math/rand" - "os" - "strconv" "time" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -15,48 +13,10 @@ import ( // In the future we could make this configurable const ( - statusInterval = 2 * time.Second - statusTimeout = 5 * time.Second - defaultStatusGracePeriod = 30 * time.Second - statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD" + statusInterval = 2 * time.Second + statusTimeout = 5 * time.Second ) -// statusGracePeriod returns the status grace period setting for this -// worker, in seconds. -// -// The grace period is the length of time we allow connections to run -// on a worker in the event of an error sending status updates. The -// period is defined the length of time since the last successful -// update. -// -// The setting is derived from one of the following: -// -// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an -// integer value to define the setting. -// * If this is missing, the default (30 seconds) is used. -// -func (w *Worker) statusGracePeriod() time.Duration { - if v := os.Getenv(statusGracePeriodEnvVar); v != "" { - n, err := strconv.Atoi(v) - if err != nil { - w.logger.Error("could not read setting for BOUNDARY_STATUS_GRACE_PERIOD, using default", - "err", err, - "value", v, - ) - return defaultStatusGracePeriod - } - - if n < 1 { - w.logger.Error("invalid setting for BOUNDARY_STATUS_GRACE_PERIOD, using default", "value", v) - return defaultStatusGracePeriod - } - - return time.Second * time.Duration(n) - } - - return defaultStatusGracePeriod -} - type LastStatusInformation struct { *pbs.StatusResponse StatusTime time.Time @@ -99,6 +59,37 @@ func (w *Worker) LastStatusSuccess() *LastStatusInformation { return w.lastStatusSuccess.Load().(*LastStatusInformation) } +// WaitForNextSuccessfulStatusUpdate waits for the next successful status. It's +// used by testing (and in the future, shutdown) in place of a more opaque and +// possibly unnecessarily long sleep for things like initial controller +// check-in, etc. +// +// The timeout is aligned with the worker's status grace period. A nil error +// means the status was sent successfully. +func (w *Worker) WaitForNextSuccessfulStatusUpdate() error { + w.logger.Debug("waiting for next status report to controller") + waitStatusStart := time.Now() + ctx, cancel := context.WithTimeout(w.baseContext, w.conf.StatusGracePeriodDuration) + defer cancel() + for { + select { + case <-time.After(time.Second): + // pass + + case <-ctx.Done(): + w.logger.Error("error waiting for next status report to controller", "err", ctx.Err()) + return ctx.Err() + } + + if w.lastSuccessfulStatusTime().Sub(waitStatusStart) > 0 { + break + } + } + + w.logger.Debug("next worker status update sent successfully") + return nil +} + func (w *Worker) sendWorkerStatus(cancelCtx context.Context) { // First send info as-is. We'll perform cleanup duties after we // get cancel/job change info back. @@ -196,6 +187,7 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) { w.lastStatusSuccess.Store(&LastStatusInformation{StatusResponse: result, StatusTime: time.Now()}) for _, request := range result.GetJobsRequests() { + w.logger.Trace("got job request from controller", "request", request) switch request.GetRequestType() { case pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE: switch request.GetJob().GetType() { @@ -204,12 +196,25 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) { sessionId := sessInfo.GetSessionId() siRaw, ok := w.sessionInfoMap.Load(sessionId) if !ok { - w.logger.Warn("asked to cancel session but could not find a local information for it", "session_id", sessionId) + w.logger.Warn("session change requested but could not find local information for it", "session_id", sessionId) continue } si := siRaw.(*sessionInfo) si.Lock() si.status = sessInfo.GetStatus() + // Update connection state if there are any connections in + // the request. + for _, conn := range sessInfo.GetConnections() { + connId := conn.GetConnectionId() + connInfo, ok := si.connInfoMap[conn.GetConnectionId()] + if !ok { + w.logger.Warn("connection change requested but could not find local information for it", "connection_id", connId) + continue + } + + connInfo.status = conn.GetStatus() + } + si.Unlock() } } @@ -240,22 +245,31 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING, si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_TERMINATED, time.Until(si.lookupSessionResponse.Expiration.AsTime()) < 0: - var toClose int - for k, v := range si.connInfoMap { - if v.closeTime.IsZero() { - toClose++ - v.connCancel() - w.logger.Info("terminated connection due to cancellation or expiration", "session_id", si.id, "connection_id", k) - closeInfo[k] = si.id - } + // Cancel connections without regard to individual connection + // state. + closedIds := w.cancelConnections(si.connInfoMap, true) + for _, connId := range closedIds { + closeInfo[connId] = si.id + w.logClose(si.id, connId) } + // closeTime is marked by closeConnections iff the // status is returned for that connection as closed. If // the session is no longer valid and all connections // are marked closed, clean up the session. - if toClose == 0 { + if len(closedIds) == 0 { cleanSessionIds = append(cleanSessionIds, si.id) } + + default: + // Cancel connections *with* regard to individual connection + // state (ie: only ones that the controller has requested be + // terminated). + closedIds := w.cancelConnections(si.connInfoMap, false) + for _, connId := range closedIds { + closeInfo[connId] = si.id + w.logClose(si.id, connId) + } } return true @@ -277,6 +291,33 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat } } +// cancelConnections is run by cleanupConnections to iterate over a +// session's connInfoMap and close connections based on the +// connection's state (or regardless if ignoreConnectionState is +// set). +// +// The returned map and slice are the maps of connection -> session, +// and sessions to completely remove from local state, respectively. +func (w *Worker) cancelConnections(connInfoMap map[string]*connInfo, ignoreConnectionState bool) []string { + var closedIds []string + for k, v := range connInfoMap { + if v.closeTime.IsZero() { + if !ignoreConnectionState && v.status != pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED { + continue + } + + v.connCancel() + closedIds = append(closedIds, k) + } + } + + return closedIds +} + +func (w *Worker) logClose(sessionId, connId string) { + w.logger.Info("terminated connection due to cancellation or expiration", "session_id", sessionId, "connection_id", connId) +} + func (w *Worker) lastSuccessfulStatusTime() time.Time { lastStatus := w.LastStatusSuccess() if lastStatus == nil { @@ -288,7 +329,7 @@ func (w *Worker) lastSuccessfulStatusTime() time.Time { func (w *Worker) isPastGrace() (bool, time.Time, time.Duration) { t := w.lastSuccessfulStatusTime() - u := w.statusGracePeriod() + u := w.conf.StatusGracePeriodDuration v := time.Since(t) return v > u, t, u } diff --git a/internal/servers/worker/status_test.go b/internal/servers/worker/status_test.go new file mode 100644 index 0000000000..bf743f8343 --- /dev/null +++ b/internal/servers/worker/status_test.go @@ -0,0 +1,56 @@ +package worker + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +func TestWorkerWaitForNextSuccessfulStatusUpdate(t *testing.T) { + t.Parallel() + require := require.New(t) + for _, name := range []string{"ok", "timeout"} { + t.Run(name, func(t *testing.T) { + // As-needed initialization of a mock worker + w := &Worker{ + logger: hclog.New(nil), + lastStatusSuccess: new(atomic.Value), + baseContext: context.Background(), + conf: &Config{ + Server: &base.Server{ + StatusGracePeriodDuration: time.Second * 1, + }, + }, + } + + // This is present in New() + w.lastStatusSuccess.Store((*LastStatusInformation)(nil)) + + var wg sync.WaitGroup + var err error + wg.Add(1) + go func() { + err = w.WaitForNextSuccessfulStatusUpdate() + wg.Done() + }() + + if name == "ok" { + time.Sleep(time.Millisecond * 100) + w.lastStatusSuccess.Store(&LastStatusInformation{StatusTime: time.Now()}) + } + + wg.Wait() + if name == "timeout" { + require.ErrorIs(err, context.DeadlineExceeded) + } else { + require.NoError(err) + } + }) + } +} diff --git a/internal/servers/worker/testing.go b/internal/servers/worker/testing.go index ddd2bc1952..fedaa0f397 100644 --- a/internal/servers/worker/testing.go +++ b/internal/servers/worker/testing.go @@ -176,6 +176,10 @@ type TestWorkerOpts struct { // The logger to use, or one will be created Logger hclog.Logger + + // The amount of time to wait before marking connections as closed when a + // connection cannot be made back to the controller + StatusGracePeriodDuration time.Duration } func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker { @@ -219,6 +223,9 @@ func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker { }) } + // Initialize status grace period + tw.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration) + if opts.Config.Worker == nil { opts.Config.Worker = new(config.Worker) } @@ -278,10 +285,11 @@ func (tw *TestWorker) AddClusterWorkerMember(t *testing.T, opts *TestWorkerOpts) opts = new(TestWorkerOpts) } nextOpts := &TestWorkerOpts{ - WorkerAuthKms: tw.w.conf.WorkerAuthKms, - Name: opts.Name, - InitialControllers: tw.ControllerAddrs(), - Logger: tw.w.conf.Logger, + WorkerAuthKms: tw.w.conf.WorkerAuthKms, + Name: opts.Name, + InitialControllers: tw.ControllerAddrs(), + Logger: tw.w.conf.Logger, + StatusGracePeriodDuration: opts.StatusGracePeriodDuration, } if opts.Logger != nil { nextOpts.Logger = opts.Logger diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index dd62e74d4a..944b243421 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -28,8 +28,8 @@ type Worker struct { started *ua.Bool controllerStatusConn *atomic.Value - workerStartTime time.Time lastStatusSuccess *atomic.Value + workerStartTime time.Time controllerResolver *atomic.Value diff --git a/internal/session/query.go b/internal/session/query.go index c294185c41..563f3a5e69 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -235,13 +235,13 @@ where ) ` - // connectionsToCloseCte finds connections that are: + // closeDeadConnectionsCte finds connections that are: // // * not closed // * not announced by a given server in its latest update // // and marks them as closed. - connectionsToCloseCte = ` + closeDeadConnectionsCte = ` with -- Find connections that are not closed so we can reference those IDs unclosed_connections as ( @@ -257,7 +257,7 @@ with -- It's not in limbo between when it moved into this state and when -- it started being reported by the worker, which is roughly every -- 2-3 seconds - start_time < now() - interval '10 seconds' + start_time < wt_sub_seconds_from_now(10) ), connections_to_close as ( select public_id @@ -278,5 +278,100 @@ with where public_id in (select public_id from connections_to_close) returning public_id + ` + + // closeConnectionsForDeadServersCte finds connections that are: + // + // * not closed + // * belong to servers that have not reported in within an acceptable + // threshold of time + // + // and marks them as closed. + // + // The query returns the set of servers that have had connections closed + // along with their last update time and the number of connections closed on + // each. + closeConnectionsForDeadServersCte = ` +with + -- Get dead servers, parameterized off of grace period in seconds + dead_servers as ( + select private_id, update_time + from server + where update_time < wt_sub_seconds_from_now($1) + ), + -- Find connections that are not closed so we can reference those IDs + unclosed_connections as ( + select connection_id + from session_connection_state + where + -- It's the current state + end_time is null + and + -- Current state isn't closed state + state in ('authorized', 'connected') + and + -- It's not in limbo between when it moved into this state and when + -- it started being reported by the worker, which is roughly every + -- 2-3 seconds + start_time < wt_sub_seconds_from_now(10) + ), + connections_to_close as ( + select public_id + from session_connection + where + -- Related to the worker that just reported to us + server_id in (select private_id from dead_servers) + and + -- Only unclosed ones + public_id in (select connection_id from unclosed_connections) + ), + closed_connections as ( + update session_connection + set + closed_reason = 'system error' + where + public_id in (select public_id from connections_to_close) + returning public_id, server_id + ) + select + dead_servers.private_id, + dead_servers.update_time, + count(closed_connections.public_id) + from dead_servers + left join closed_connections + on dead_servers.private_id = closed_connections.server_id + group by dead_servers.private_id, dead_servers.update_time + having count(closed_connections.public_id) > 0 + order by dead_servers.private_id + ` + + // shouldCloseConnectionsCte finds connections that are marked as closed in + // the database given a set of connection IDs. They are returned along with + // their associated session ID. + // + // The second parameter is a set of session IDs that we have already + // submitted a session-wide close request for, so sending another change + // request for them would be redundant. + shouldCloseConnectionsCte = ` +with + -- Find connections that are closed so we can reference those IDs + closed_connections as ( + select connection_id + from session_connection_state + where + -- It's the current state + end_time is null + and + -- Current state is closed + state = 'closed' + and + connection_id in (%s) + ) + select public_id, session_id + from session_connection + where + public_id in (select connection_id from closed_connections) + -- Below fmt arg is filled in if there are session IDs to filter against + %s ` ) diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go index c145c14baa..80ff88ef3c 100644 --- a/internal/session/repository_connection.go +++ b/internal/session/repository_connection.go @@ -4,11 +4,18 @@ import ( "context" "fmt" "strings" + "time" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/servers" ) +// deadWorkerConnCloseMinGrace is the minimum allowable setting for +// the CloseConnectionsForDeadWorkers method. This is synced with +// the default server liveness setting. +var DeadWorkerConnCloseMinGrace = int(servers.DefaultLiveness.Seconds()) + // LookupConnection will look up a connection in the repository and return the connection // with its states. If the connection is not found, it will return nil, nil, nil. // No options are currently supported. @@ -99,18 +106,16 @@ func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ .. return rowsDeleted, nil } -// CloseDeadConnectionsOnWorkerReport will run the connectionsToClose CTE to -// look for connections that should be marked closed because they are no longer -// claimed by a server. This does notdetect connections where the server is no -// longer reporting status; that's for a different CTE to be called by a -// heartbeat detector. +// CloseDeadConnectionsForWorker will run closeDeadConnectionsCte to look for +// connections that should be marked closed because they are no longer claimed +// by a server. // // The foundConns input should be the currently-claimed connections; the CTE // uses a NOT IN clause to ensure these are excluded. It is not an error for // this to be empty as the worker could claim no connections; in that case all // connections will immediately transition to closed. -func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, serverId string, foundConns []string) (int, error) { - const op = "session.(Repository).CloseDeadConnectionsOnWorkerReport" +func (r *Repository) CloseDeadConnectionsForWorker(ctx context.Context, serverId string, foundConns []string) (int, error) { + const op = "session.(Repository).CloseDeadConnectionsForWorker" if serverId == "" { return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing server id") } @@ -135,7 +140,7 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { var err error - rowsAffected, err = w.Exec(ctx, fmt.Sprintf(connectionsToCloseCte, publicIdStr), args) + rowsAffected, err = w.Exec(ctx, fmt.Sprintf(closeDeadConnectionsCte, publicIdStr), args) if err != nil { return errors.Wrap(err, op) } @@ -148,6 +153,129 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser return rowsAffected, nil } +type CloseConnectionsForDeadWorkersResult struct { + ServerId string + LastUpdateTime time.Time + NumberConnectionsClosed int +} + +// CloseConnectionsForDeadWorkers will run +// closeConnectionsForDeadServersCte to look for connections that +// should be marked because they are on a server that is no longer +// sending status updates to the controller(s). +// +// The only input to the method is the grace period, in seconds. +func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod int) ([]CloseConnectionsForDeadWorkersResult, error) { + const op = "session.(Repository).CloseConnectionsForDeadWorkers" + if gracePeriod < DeadWorkerConnCloseMinGrace { + return nil, errors.New( + errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace)) + } + + results := make([]CloseConnectionsForDeadWorkersResult, 0) + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, []interface{}{gracePeriod}) + if err != nil { + return errors.Wrap(err, op) + } + defer rows.Close() + + for rows.Next() { + var ( + serverId string + lastUpdateTime time.Time + numberConnectionsClosed int + ) + + if err := rows.Scan(&serverId, &lastUpdateTime, &numberConnectionsClosed); err != nil { + return errors.Wrap(err, op) + } + + results = append(results, CloseConnectionsForDeadWorkersResult{ + ServerId: serverId, + LastUpdateTime: lastUpdateTime, + NumberConnectionsClosed: numberConnectionsClosed, + }) + } + + return nil + }, + ) + + if err != nil { + return nil, errors.Wrap(err, op) + } + + return results, nil +} + +// ShouldCloseConnectionsOnWorker will run shouldCloseConnectionsCte to look +// for connections that the worker should close because they are currently +// reporting them as open incorrectly. +// +// The foundConns input here is used to filter closed connection states. This +// is further filtered against the filterSessions input, which is expected to +// be a set of sessions we've already submitted close requests for, so adding +// them again would be redundant. +// +// The returned map[string][]string is indexed by session ID. +func (r *Repository) ShouldCloseConnectionsOnWorker(ctx context.Context, foundConns, filterSessions []string) (map[string][]string, error) { + const op = "session.(Repository).ShouldCloseConnectionsOnWorker" + if len(foundConns) < 1 { + return nil, nil // nothing to do + } + + args := make([]interface{}, 0, len(foundConns)+len(filterSessions)) + + // foundConns first + connsParams := make([]string, len(foundConns)) + for i, connId := range foundConns { + connsParams[i] = fmt.Sprintf("$%d", i+1) + args = append(args, connId) + } + connsStr := strings.Join(connsParams, ",") + + // then filterSessions + var sessionsStr string + if len(filterSessions) > 0 { + offset := len(foundConns) + 1 + sessionsParams := make([]string, len(filterSessions)) + for i, sessionId := range filterSessions { + sessionsParams[i] = fmt.Sprintf("$%d", i+offset) + args = append(args, sessionId) + } + + const sessionIdFmtStr = `and session_id not in (%s)` + sessionsStr = fmt.Sprintf(sessionIdFmtStr, strings.Join(sessionsParams, ",")) + } + + rows, err := r.reader.Query( + ctx, + fmt.Sprintf(shouldCloseConnectionsCte, connsStr, sessionsStr), + args, + ) + if err != nil { + return nil, errors.Wrap(err, op) + } + defer rows.Close() + + result := make(map[string][]string) + for rows.Next() { + var connectionId, sessionId string + if err := rows.Scan(&connectionId, &sessionId); err != nil { + return nil, errors.Wrap(err, op) + } + + result[sessionId] = append(result[sessionId], connectionId) + } + + return result, nil +} + func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) { const op = "session.fetchConnectionStates" var states []*ConnectionState diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go index 18454793a5..b0498d4361 100644 --- a/internal/session/repository_connection_test.go +++ b/internal/session/repository_connection_test.go @@ -2,6 +2,7 @@ package session import ( "context" + "fmt" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/shared-secure-libs/strutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,7 +212,7 @@ func TestRepository_DeleteConnection(t *testing.T) { } } -func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { +func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { t.Parallel() require, assert := require.New(t), assert.New(t) conn, _ := db.TestSetup(t, "postgres") @@ -284,7 +286,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { // all connection IDs for worker 1 should be showing as non-closed, and // the ones for worker 2 not advertised should be closed. shouldStayOpen := worker2ConnIds[0:2] - count, err := repo.CloseDeadConnectionsOnWorkerReport(ctx, worker2.GetPrivateId(), shouldStayOpen) + count, err := repo.CloseDeadConnectionsForWorker(ctx, worker2.GetPrivateId(), shouldStayOpen) require.NoError(err) assert.Equal(4, count) @@ -310,7 +312,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { // Now, advertise none of the connection IDs for worker 2. This is mainly to // test that handling the case where we do not include IDs works properly as // it changes the where clause. - count, err = repo.CloseDeadConnectionsOnWorkerReport(ctx, worker1.GetPrivateId(), nil) + count, err = repo.CloseDeadConnectionsForWorker(ctx, worker1.GetPrivateId(), nil) require.NoError(err) assert.Equal(6, count) @@ -330,3 +332,440 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { assert.True(foundClosed != strutil.StrListContains(shouldStayOpen, conn.PublicId)) } } + +func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { + t.Parallel() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(err) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + + // connection count = 6 * states(authorized, connected, closed = 3) * servers_with_open_connections(3) + numConns := 54 + + // Create four "workers". This is similar to the setup in + // TestRepository_CloseDeadConnectionsOnWorker, but a bit more complex; + // firstly, the last worker will have no connections at all, and we will be + // closing the others in stages to test multiple servers being closed at + // once. + worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1")) + worker2 := TestWorker(t, conn, wrapper, WithServerId("worker2")) + worker3 := TestWorker(t, conn, wrapper, WithServerId("worker3")) + worker4 := TestWorker(t, conn, wrapper, WithServerId("worker4")) + + // Create sessions on the first three, activate, and authorize connections + var worker1ConnIds, worker2ConnIds, worker3ConnIds []string + for i := 0; i < numConns; i++ { + var serverId string + if i%3 == 0 { + serverId = worker1.PrivateId + } else if i%3 == 1 { + serverId = worker2.PrivateId + } else { + serverId = worker3.PrivateId + } + sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) + sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) + require.NoError(err) + c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + require.NoError(err) + require.Len(cs, 1) + require.Equal(StatusAuthorized, cs[0].Status) + if i%3 == 0 { + worker1ConnIds = append(worker1ConnIds, c.GetPublicId()) + } else if i%3 == 1 { + worker2ConnIds = append(worker2ConnIds, c.GetPublicId()) + } else { + worker3ConnIds = append(worker3ConnIds, c.GetPublicId()) + } + } + + // Mark a third of the connections connected, a third closed, and leave the + // others authorized. This is just to ensure we have a spread when we test it + // out. + for i, connId := range func() []string { + var s []string + s = append(s, worker1ConnIds...) + s = append(s, worker2ConnIds...) + s = append(s, worker3ConnIds...) + return s + }() { + if i%3 == 0 { + _, cs, err := repo.ConnectConnection(ctx, ConnectWith{ + ConnectionId: connId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 22, + }) + require.NoError(err) + require.Len(cs, 2) + var foundAuthorized, foundConnected bool + for _, status := range cs { + if status.Status == StatusAuthorized { + foundAuthorized = true + } + if status.Status == StatusConnected { + foundConnected = true + } + } + require.True(foundAuthorized) + require.True(foundConnected) + } else if i%3 == 1 { + resp, err := repo.CloseConnections(ctx, []CloseWith{ + { + ConnectionId: connId, + ClosedReason: ConnectionCanceled, + }, + }) + require.NoError(err) + require.Len(resp, 1) + cs := resp[0].ConnectionStates + require.Len(cs, 2) + var foundAuthorized, foundClosed bool + for _, status := range cs { + if status.Status == StatusAuthorized { + foundAuthorized = true + } + if status.Status == StatusClosed { + foundClosed = true + } + } + require.True(foundAuthorized) + require.True(foundClosed) + } + } + + // There is a 15 second delay to account for time for the connections to + // transition + time.Sleep(15 * time.Second) + + // updateServer is a helper for updating the update time for our + // servers. The controller is read back so that we can reference + // the most up-to-date fields. + updateServer := func(t *testing.T, w *servers.Server) *servers.Server { + t.Helper() + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, w) + require.NoError(err) + require.Equal(1, rowsUpdated) + servers, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker) + require.NoError(err) + for _, server := range servers { + if server.PrivateId == w.PrivateId { + return server + } + } + + require.FailNowf("server %q not found after updating", w.PrivateId) + // Looks weird but needed to build, as we fail in testify instead + // of returning an error + return nil + } + + // requireConnectionStatus is a helper expecting all connections on a worker + // to be closed. + requireConnectionStatus := func(t *testing.T, connIds []string, expectAllClosed bool) { + t.Helper() + + var conns []*Connection + require.NoError(repo.list(ctx, &conns, "", nil)) + for i, connId := range connIds { + var expected ConnectionStatus + switch { + case expectAllClosed: + expected = StatusClosed + + case i%3 == 0: + expected = StatusConnected + + case i%3 == 1: + expected = StatusClosed + + case i%3 == 2: + expected = StatusAuthorized + } + + _, states, err := repo.LookupConnection(ctx, connId) + require.NoError(err) + require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected) + } + } + + // We need this helper to fix the zone on protobuf timestamps + // versus what gets reported in the + // CloseConnectionsForDeadWorkersResult. + timestampPbAsUTC := func(t *testing.T, tm time.Time) time.Time { + t.Helper() + utcLoc, err := time.LoadLocation("Etc/UTC") + require.NoError(err) + return tm.In(utcLoc) + } + + // Now try some scenarios. + { + // First, test the error/validation case. + result, err := repo.CloseConnectionsForDeadWorkers(ctx, 0) + require.Equal(err, errors.E( + errors.WithCode(errors.InvalidParameter), + errors.WithOp("session.(Repository).CloseConnectionsForDeadWorkers"), + errors.WithMsg(fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace)), + )) + require.Nil(result) + } + + { + // Now, try the basis, or where all workers are reporting in. + worker1 = updateServer(t, worker1) + worker2 = updateServer(t, worker2) + worker3 = updateServer(t, worker3) + updateServer(t, worker4) // no re-assignment here because we never reference the server again + + result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + require.NoError(err) + require.Empty(result) + // Expect appropriate split connection state on worker1 + requireConnectionStatus(t, worker1ConnIds, false) + // Expect appropriate split connection state on worker2 + requireConnectionStatus(t, worker2ConnIds, false) + // Expect appropriate split connection state on worker3 + requireConnectionStatus(t, worker3ConnIds, false) + } + + { + // Now try a zero case - similar to the basis, but only in that no results + // are expected to be returned for workers with no connections, even if + // they are dead. Here, the server with no connections is worker #4. + time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + worker1 = updateServer(t, worker1) + worker2 = updateServer(t, worker2) + worker3 = updateServer(t, worker3) + + result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + require.NoError(err) + require.Empty(result) + // Expect appropriate split connection state on worker1 + requireConnectionStatus(t, worker1ConnIds, false) + // Expect appropriate split connection state on worker2 + requireConnectionStatus(t, worker2ConnIds, false) + // Expect appropriate split connection state on worker3 + requireConnectionStatus(t, worker3ConnIds, false) + } + + { + // The first induction is letting the first worker "die" by not updating it + // too. All of its authorized and connected connections should be dead. + time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + worker2 = updateServer(t, worker2) + worker3 = updateServer(t, worker3) + + result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + require.NoError(err) + // Assert that we have one result with the appropriate ID and + // number of connections closed. Due to how things are + require.Equal([]CloseConnectionsForDeadWorkersResult{ + { + ServerId: worker1.PrivateId, + LastUpdateTime: timestampPbAsUTC(t, worker1.UpdateTime.AsTime()), + NumberConnectionsClosed: 12, // 18 per server, with 6 closed already + }, + }, result) + // Expect all connections closed on worker1 + requireConnectionStatus(t, worker1ConnIds, true) + // Expect appropriate split connection state on worker2 + requireConnectionStatus(t, worker2ConnIds, false) + // Expect appropriate split connection state on worker3 + requireConnectionStatus(t, worker3ConnIds, false) + } + + { + // The final case is having the other two workers die. After + // this, we should have all connections closed with the + // appropriate message from the next two servers acted on. + time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + + result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + require.NoError(err) + // Assert that we have one result with the appropriate ID and number of connections closed. + require.Equal([]CloseConnectionsForDeadWorkersResult{ + { + ServerId: worker2.PrivateId, + LastUpdateTime: timestampPbAsUTC(t, worker2.UpdateTime.AsTime()), + NumberConnectionsClosed: 12, // 18 per server, with 6 closed already + }, + { + ServerId: worker3.PrivateId, + LastUpdateTime: timestampPbAsUTC(t, worker3.UpdateTime.AsTime()), + NumberConnectionsClosed: 12, // 18 per server, with 6 closed already + }, + }, result) + // Expect all connections closed on worker1 + requireConnectionStatus(t, worker1ConnIds, true) + // Expect all connections closed on worker2 + requireConnectionStatus(t, worker2ConnIds, true) + // Expect all connections closed on worker3 + requireConnectionStatus(t, worker3ConnIds, true) + } +} + +func TestRepository_ShouldCloseConnectionsOnWorker(t *testing.T) { + t.Parallel() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + numConns := 12 + + // Create a worker, we only need one here as our query is dependent + // on connection and not worker. + worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1")) + + // Create a few sessions on each, activate, and authorize a connection + var connIds []string + sessionConnIds := make(map[string][]string) + for i := 0; i < numConns; i++ { + serverId := worker1.PrivateId + sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) + sessionId := sess.GetPublicId() + sess, _, err = repo.ActivateSession(ctx, sessionId, sess.Version, serverId, "worker", []byte("foo")) + require.NoError(err) + c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + require.NoError(err) + require.Len(cs, 1) + require.Equal(StatusAuthorized, cs[0].Status) + connId := c.GetPublicId() + connIds = append(connIds, connId) + sessionConnIds[sessionId] = append(sessionConnIds[sessionId], connId) + } + + // Mark half of the connections connected, close the other half. + for i, connId := range connIds { + if i%2 == 0 { + _, cs, err := repo.ConnectConnection(ctx, ConnectWith{ + ConnectionId: connId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 22, + }) + require.NoError(err) + require.Len(cs, 2) + var foundAuthorized, foundConnected bool + for _, status := range cs { + if status.Status == StatusAuthorized { + foundAuthorized = true + } + if status.Status == StatusConnected { + foundConnected = true + } + } + require.True(foundAuthorized) + require.True(foundConnected) + } else { + resp, err := repo.CloseConnections(ctx, []CloseWith{ + { + ConnectionId: connId, + ClosedReason: ConnectionCanceled, + }, + }) + require.NoError(err) + require.Len(resp, 1) + cs := resp[0].ConnectionStates + require.Len(cs, 2) + var foundAuthorized, foundClosed bool + for _, status := range cs { + if status.Status == StatusAuthorized { + foundAuthorized = true + } + if status.Status == StatusClosed { + foundClosed = true + } + } + require.True(foundAuthorized) + require.True(foundClosed) + } + } + + // There is a 10 second delay to account for time for the connections to + // transition + time.Sleep(15 * time.Second) + + // Now we try some scenarios. + { + // First test an empty set. + result, err := repo.ShouldCloseConnectionsOnWorker(ctx, nil, nil) + require.NoError(err) + require.Zero(result, "should be empty when no connections are supplied") + } + + { + // Here we pass in all of our connections without a filter on + // session. This should return half of the connections - the ones + // that we marked as closed. + // + // Create a copy of our session map with the sessions that have + // closed connections taken out. + expectedSessionConnIds := make(map[string][]string) + for sessionId, connIds := range sessionConnIds { + for _, connId := range connIds { + if testIsConnectionClosed(ctx, t, repo, connId) { + expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId) + } + } + } + + // Send query, use all connections w/o a filter on sessions. + actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, nil) + require.NoError(err) + require.Equal(expectedSessionConnIds, actualSessionConnIds) + } + + { + // Finally, add a session filter. We do this by just alternating + // the session IDs we want to filter on. + expectedSessionConnIds := make(map[string][]string) + var filterSessionIds []string + var filterSession bool + for sessionId, connIds := range sessionConnIds { + for _, connId := range connIds { + if testIsConnectionClosed(ctx, t, repo, connId) { + if !filterSession { + expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId) + } else { + filterSessionIds = append(filterSessionIds, sessionId) + } + + // Toggle filterSession here (instead of just outer session + // loop) so that we aren't just lining up on + // connected/disconnected connections. + filterSession = !filterSession + } + } + } + + // Send query with the session filter. + actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, filterSessionIds) + require.NoError(err) + require.Equal(expectedSessionConnIds, actualSessionConnIds) + } +} + +func testIsConnectionClosed(ctx context.Context, t *testing.T, repo *Repository, connId string) bool { + require := require.New(t) + _, states, err := repo.LookupConnection(ctx, connId) + require.NoError(err) + // Use first state as this LookupConnections returns ordered by + // start time, descending. + return states[0].Status == StatusClosed +} diff --git a/internal/tests/cluster/session_cleanup_test.go b/internal/tests/cluster/session_cleanup_test.go index ab506a2df1..54140d93a1 100644 --- a/internal/tests/cluster/session_cleanup_test.go +++ b/internal/tests/cluster/session_cleanup_test.go @@ -37,9 +37,9 @@ import ( ) const ( - testSendRecvSendMax = 60 - defaultGracePeriod = time.Second * 30 - expectConnectionStateOnWorkerTimeout = defaultGracePeriod * 2 + defaultGracePeriod = time.Second * 15 + expectConnectionStateOnControllerTimeout = time.Minute * 2 + expectConnectionStateOnWorkerTimeout = defaultGracePeriod * 2 // This is the interval that we check states on in the worker. It // needs to be particularly granular to ensure that we allow for @@ -55,227 +55,414 @@ const ( expectConnectionStateOnWorkerInterval = time.Millisecond * 100 ) -func TestWorkerSessionCleanup(t *testing.T) { - require := require.New(t) - logger := hclog.New(&hclog.LoggerOptions{ - Level: hclog.Trace, - }) - - conf, err := config.DevController() - require.NoError(err) - - pl, err := net.Listen("tcp", "localhost:0") - require.NoError(err) - c1 := controller.NewTestController(t, &controller.TestControllerOpts{ - Config: conf, - InitialResourcesSuffix: "1234567890", - Logger: logger.Named("c1"), - PublicClusterAddr: pl.Addr().String(), - }) - defer c1.Shutdown() - - expectWorkers(t, c1) - - // Wire up the testing proxies - require.Len(c1.ClusterAddrs(), 1) - proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], - dawdle.WithListener(pl), - dawdle.WithRbufSize(512), - dawdle.WithWbufSize(512), - ) - require.NoError(err) - defer proxy.Close() - require.NotEmpty(t, proxy.ListenerAddr()) - - w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ - WorkerAuthKms: c1.Config().WorkerAuthKms, - InitialControllers: []string{proxy.ListenerAddr()}, - Logger: logger.Named("w1"), - }) - defer w1.Shutdown() +// timeoutBurdenType details our "burden cases" for the session +// cleanup tests. +// +// There are 3 burden cases: +// +// * default: This case simulates normal default operation where both +// worker and controller generally are timing out connections at +// generally the same interval. In reality, this is not necessarily +// going to be the case, but it's hard to test individual cases when +// both settings are the same. +// +// * worker: This case assumes the worker is the source of truth for +// controller state. Here, the controller's grace period is +// increased to a high factor over the default to ensure that the +// worker is managing the lifecycle of a connection and will properly +// unclaim it closed once the connection resumes, ensuring the +// connection is marked as closed on the worker. +// +// * controller: Here, the controller is the one doing the work. The +// connection will be open on the worker until status checks resume +// from the worker. At this point, the controller will request the +// status change on the worker, physically closing the connection +// there. +type timeoutBurdenType string - time.Sleep(10 * time.Second) - expectWorkers(t, c1, w1) +const ( + timeoutBurdenTypeDefault timeoutBurdenType = "default" + timeoutBurdenTypeWorker timeoutBurdenType = "worker" + timeoutBurdenTypeController timeoutBurdenType = "controller" +) - // Use an independent context for test things that take a context so - // that we aren't tied to any timeouts in the controller, etc. This - // can interfere with some of the test operations. - ctx := context.Background() +var timeoutBurdenCases = []timeoutBurdenType{timeoutBurdenTypeDefault, timeoutBurdenTypeWorker, timeoutBurdenTypeController} - // Connect target - client := c1.Client() - client.SetToken(c1.Token().Token) - tcl := targets.NewClient(client) - tgt, err := tcl.Read(ctx, "ttcp_1234567890") - require.NoError(err) - require.NotNil(tgt) +func controllerGracePeriod(ty timeoutBurdenType) time.Duration { + if ty == timeoutBurdenTypeWorker { + return defaultGracePeriod * 10 + } - // Create test server, update default port on target - ts := newTestTcpServer(t, logger) - require.NotNil(t, ts) - defer ts.Close() - tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port())) - require.NoError(err) - require.NotNil(tgt) + return defaultGracePeriod +} - // Authorize and connect - sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") - sConn := sess.Connect(ctx, t, logger) +func workerGracePeriod(ty timeoutBurdenType) time.Duration { + if ty == timeoutBurdenTypeController { + return defaultGracePeriod * 10 + } - // Run initial send/receive test, make sure things are working - sConn.TestSendRecvAll(t) + return defaultGracePeriod +} - // Kill the link - proxy.Pause() +// TestWorkerSessionCleanup is the main test for session cleanup, and +// dispatches to the individual subtests. +func TestWorkerSessionCleanup(t *testing.T) { + t.Parallel() + for _, burdenCase := range timeoutBurdenCases { + burdenCase := burdenCase + t.Run(string(burdenCase), func(t *testing.T) { + t.Parallel() + t.Run("single_controller", testWorkerSessionCleanupSingle(burdenCase)) + t.Run("multi_controller", testWorkerSessionCleanupMulti(burdenCase)) + }) + } +} - // Run again, ensure connection is dead - sConn.TestSendRecvFail(t) +func testWorkerSessionCleanupSingle(burdenCase timeoutBurdenType) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + require := require.New(t) + logger := hclog.New(&hclog.LoggerOptions{ + Name: t.Name(), + Level: hclog.Trace, + }) + + conf, err := config.DevController() + require.NoError(err) + + pl, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + InitialResourcesSuffix: "1234567890", + Logger: logger.Named("c1"), + PublicClusterAddr: pl.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c1.Shutdown() + + expectWorkers(t, c1) + + // Wire up the testing proxies + require.Len(c1.ClusterAddrs(), 1) + proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], + dawdle.WithListener(pl), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer proxy.Close() + require.NotEmpty(t, proxy.ListenerAddr()) + + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: []string{proxy.ListenerAddr()}, + Logger: logger.Named("w1"), + StatusGracePeriodDuration: workerGracePeriod(burdenCase), + }) + defer w1.Shutdown() + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + err = c1.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + expectWorkers(t, c1, w1) + + // Use an independent context for test things that take a context so + // that we aren't tied to any timeouts in the controller, etc. This + // can interfere with some of the test operations. + ctx := context.Background() + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + require.NoError(err) + require.NotNil(tgt) + + // Create test server, update default port on target + ts := newTestTcpServer(t, logger) + require.NotNil(t, ts) + defer ts.Close() + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + require.NoError(err) + require.NotNil(tgt) + + // Authorize and connect + sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") + sConn := sess.Connect(ctx, t) + + // Run initial send/receive test, make sure things are working + logger.Debug("running initial send/recv test") + sConn.TestSendRecvAll(t) + + // Kill the link + logger.Debug("pausing controller/worker link") + proxy.Pause() + + // Wait for failure connection state (depends on burden case) + switch burdenCase { + case timeoutBurdenTypeWorker: + // Wait on worker, then check controller + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected) + + case timeoutBurdenTypeController: + // Wait on controller, then check worker + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + + default: + // Should be closed on both worker and controller. Wait on + // worker then check controller. + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + } - // Assert we have no connections left (should be default behavior) - sess.TestNoConnectionsLeft(t) + // Run send/receive test again to check expected connection-level + // behavior + if burdenCase == timeoutBurdenTypeController { + // Burden on controller, should be successful until connection + // resumes + sConn.TestSendRecvAll(t) + } else { + // Connection should die in other cases + sConn.TestSendRecvFail(t) + } - // Assert connection has been removed from the local worker state - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + // Resume the connection, and reconnect. + logger.Debug("resuming controller/worker link") + proxy.Resume() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + + // Do something post-reconnect depending on burden case. Note in + // the default case, both worker and controller should be + // relatively in sync, so we don't worry about these + // post-reconnection assertions. + switch burdenCase { + case timeoutBurdenTypeWorker: + // If we are expecting the worker to be the source of truth of + // a connection status, ensure that our old session's + // connections are actually closed now that the worker is + // properly reporting in again. + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + + case timeoutBurdenTypeController: + // If we are expecting the controller to be the source of + // truth, the connection should now be forcibly closed after + // the worker gets a status change request back. + sConn.TestSendRecvFail(t) + } - // Resume the connection, and reconnect. - proxy.Resume() - time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy - sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() - sConn = sess.Connect(ctx, t, logger) - sConn.TestSendRecvAll(t) + // Proceed with new connection test + logger.Debug("connecting to new session after resuming controller/worker link") + sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() + sConn = sess.Connect(ctx, t) + sConn.TestSendRecvAll(t) + } } -func TestWorkerSessionCleanupMultiController(t *testing.T) { - require := require.New(t) - logger := hclog.New(&hclog.LoggerOptions{ - Level: hclog.Trace, - }) - - // ****************** - // ** Controller 1 ** - // ****************** - conf1, err := config.DevController() - require.NoError(err) +func testWorkerSessionCleanupMulti(burdenCase timeoutBurdenType) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + require := require.New(t) + logger := hclog.New(&hclog.LoggerOptions{ + Name: t.Name(), + Level: hclog.Trace, + }) + + // ****************** + // ** Controller 1 ** + // ****************** + conf1, err := config.DevController() + require.NoError(err) + + pl1, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf1, + InitialResourcesSuffix: "1234567890", + Logger: logger.Named("c1"), + PublicClusterAddr: pl1.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c1.Shutdown() + + // ****************** + // ** Controller 2 ** + // ****************** + pl2, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{ + Logger: logger.Named("c2"), + PublicClusterAddr: pl2.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c2.Shutdown() + expectWorkers(t, c1) + expectWorkers(t, c2) + + // ************* + // ** Proxy 1 ** + // ************* + require.Len(c1.ClusterAddrs(), 1) + p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], + dawdle.WithListener(pl1), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer p1.Close() + require.NotEmpty(t, p1.ListenerAddr()) + + // ************* + // ** Proxy 2 ** + // ************* + require.Len(c2.ClusterAddrs(), 1) + p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0], + dawdle.WithListener(pl2), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer p2.Close() + require.NotEmpty(t, p2.ListenerAddr()) + + // ************ + // ** Worker ** + // ************ + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()}, + Logger: logger.Named("w1"), + StatusGracePeriodDuration: workerGracePeriod(burdenCase), + }) + defer w1.Shutdown() + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + err = c1.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + err = c2.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + expectWorkers(t, c1, w1) + expectWorkers(t, c2, w1) + + // Use an independent context for test things that take a context so + // that we aren't tied to any timeouts in the controller, etc. This + // can interfere with some of the test operations. + ctx := context.Background() + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + require.NoError(err) + require.NotNil(tgt) + + // Create test server, update default port on target + ts := newTestTcpServer(t, logger) + require.NotNil(ts) + defer ts.Close() + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + require.NoError(err) + require.NotNil(tgt) + + // Authorize and connect + sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") + sConn := sess.Connect(ctx, t) + + // Run initial send/receive test, make sure things are working + logger.Debug("running initial send/recv test") + sConn.TestSendRecvAll(t) + + // Kill connection to first controller, and run test again, should + // pass, deferring to other controller. Wait for the next + // successful status report to ensure this. + logger.Debug("pausing link to controller #1") + p1.Pause() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + sConn.TestSendRecvAll(t) + + // Resume first controller, pause second. This one should work too. + logger.Debug("pausing link to controller #2, resuming #1") + p1.Resume() + p2.Pause() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + sConn.TestSendRecvAll(t) + + // Kill the first controller connection again. This one should fail + // due to lack of any connection. + logger.Debug("pausing link to controller #1 again, both connections should be offline") + p1.Pause() + + // Wait for failure connection state (depends on burden case) + switch burdenCase { + case timeoutBurdenTypeWorker: + // Wait on worker, then check controller + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected) + + case timeoutBurdenTypeController: + // Wait on controller, then check worker + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + + default: + // Should be closed on both worker and controller. Wait on + // worker then check controller. + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + } - pl1, err := net.Listen("tcp", "localhost:0") - require.NoError(err) - c1 := controller.NewTestController(t, &controller.TestControllerOpts{ - Config: conf1, - InitialResourcesSuffix: "1234567890", - Logger: logger.Named("c1"), - PublicClusterAddr: pl1.Addr().String(), - }) - defer c1.Shutdown() + // Run send/receive test again to check expected connection-level + // behavior + if burdenCase == timeoutBurdenTypeController { + // Burden on controller, should be successful until connection + // resumes + sConn.TestSendRecvAll(t) + } else { + // Connection should die in other cases + sConn.TestSendRecvFail(t) + } - // ****************** - // ** Controller 2 ** - // ****************** - pl2, err := net.Listen("tcp", "localhost:0") - require.NoError(err) - c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{ - Logger: c1.Config().Logger.ResetNamed("c2"), - PublicClusterAddr: pl2.Addr().String(), - }) - defer c2.Shutdown() - expectWorkers(t, c1) - expectWorkers(t, c2) - - // ************* - // ** Proxy 1 ** - // ************* - require.Len(c1.ClusterAddrs(), 1) - p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], - dawdle.WithListener(pl1), - dawdle.WithRbufSize(512), - dawdle.WithWbufSize(512), - ) - require.NoError(err) - defer p1.Close() - require.NotEmpty(t, p1.ListenerAddr()) - - // ************* - // ** Proxy 2 ** - // ************* - require.Len(c2.ClusterAddrs(), 1) - p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0], - dawdle.WithListener(pl2), - dawdle.WithRbufSize(512), - dawdle.WithWbufSize(512), - ) - require.NoError(err) - defer p2.Close() - require.NotEmpty(t, p2.ListenerAddr()) - - // ************ - // ** Worker ** - // ************ - w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ - WorkerAuthKms: c1.Config().WorkerAuthKms, - InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()}, - Logger: logger.Named("w1"), - }) - defer w1.Shutdown() - - time.Sleep(10 * time.Second) - expectWorkers(t, c1, w1) - expectWorkers(t, c2, w1) - - // Use an independent context for test things that take a context so - // that we aren't tied to any timeouts in the controller, etc. This - // can interfere with some of the test operations. - ctx := context.Background() - - // Connect target - client := c1.Client() - client.SetToken(c1.Token().Token) - tcl := targets.NewClient(client) - tgt, err := tcl.Read(ctx, "ttcp_1234567890") - require.NoError(err) - require.NotNil(tgt) + // Finally resume both, try again. Should behave as per normal. + logger.Debug("resuming connections to both controllers") + p1.Resume() + p2.Resume() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + + // Do something post-reconnect depending on burden case. Note in + // the default case, both worker and controller should be + // relatively in sync, so we don't worry about these + // post-reconnection assertions. + switch burdenCase { + case timeoutBurdenTypeWorker: + // If we are expecting the worker to be the source of truth of + // a connection status, ensure that our old session's + // connections are actually closed now that the worker is + // properly reporting in again. + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + + case timeoutBurdenTypeController: + // If we are expecting the controller to be the source of + // truth, the connection should now be forcibly closed after + // the worker gets a status change request back. + sConn.TestSendRecvFail(t) + } - // Create test server, update default port on target - ts := newTestTcpServer(t, logger) - require.NotNil(ts) - defer ts.Close() - tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port())) - require.NoError(err) - require.NotNil(tgt) - - // Authorize and connect - sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") - sConn := sess.Connect(ctx, t, logger) - - // Run initial send/receive test, make sure things are working - sConn.TestSendRecvAll(t) - - // Kill connection to first controller, and run test again, should - // pass, deferring to other controller. - p1.Pause() - sConn.TestSendRecvAll(t) - - // Resume first controller, pause second. This one should work too. - p1.Resume() - p2.Pause() - sConn.TestSendRecvAll(t) - - // Kill the first controller connection again. This one should fail - // due to lack of any connection. - p1.Pause() - sConn.TestSendRecvFail(t) - - // Assert we have no connections left (should be default behavior) - sess.TestNoConnectionsLeft(t) - - // Assert connection has been removed from the local worker state - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - - // Finally resume both, try again. Should behave as per normal. - p1.Resume() - p2.Resume() - time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy - sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() - sConn = sess.Connect(ctx, t, logger) - sConn.TestSendRecvAll(t) + // Proceed with new connection test + logger.Debug("connecting to new session after resuming controller/worker link") + sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() + sConn = sess.Connect(ctx, t) + sConn.TestSendRecvAll(t) + } } // testSession represents an authorized session. @@ -392,10 +579,73 @@ func (s *testSession) connect(ctx context.Context, t *testing.T) net.Conn { return websocket.NetConn(ctx, conn, websocket.MessageBinary) } -// TestNoConnectionsLeft asserts that there are no connections left. -func (s *testSession) TestNoConnectionsLeft(t *testing.T) { +// ExpectConnectionStateOnController waits until all connections in a +// session have transitioned to a particular state on the controller. +func (s *testSession) ExpectConnectionStateOnController( + ctx context.Context, + t *testing.T, + tc *controller.TestController, + expectState session.ConnectionStatus, +) { t.Helper() - require.Zero(t, s.connectionsLeft) + require := require.New(t) + assert := assert.New(t) + + ctx, cancel := context.WithTimeout(ctx, expectConnectionStateOnControllerTimeout) + defer cancel() + + // This is just for initialization of the actual state set. + const sessionStatusUnknown session.ConnectionStatus = "unknown" + + // Get all connections for the session on the controller. + sessionRepo, err := tc.Controller().SessionRepoFn() + require.NoError(err) + + conns, err := sessionRepo.ListConnectionsBySessionId(ctx, s.sessionId) + require.NoError(err) + // To avoid misleading passing tests, we require this test be used + // with sessions with connections.. + require.Greater(len(conns), 0, "should have at least one connection") + + // Make a set of states, 1 per connection + actualStates := make([]session.ConnectionStatus, len(conns)) + for i := range actualStates { + actualStates[i] = sessionStatusUnknown + } + + // Make expect set for comparison + expectStates := make([]session.ConnectionStatus, len(conns)) + for i := range expectStates { + expectStates[i] = expectState + } + + for { + if ctx.Err() != nil { + break + } + + for i, conn := range conns { + _, states, err := sessionRepo.LookupConnection(ctx, conn.PublicId, nil) + require.NoError(err) + // Look at the first state in the returned list, which will + // be the most recent state. + actualStates[i] = states[0].Status + } + + if reflect.DeepEqual(expectStates, actualStates) { + break + } + + time.Sleep(time.Second) + } + + // "non-fatal" assert here, so that we can surface both timeouts + // and invalid state + assert.NoError(ctx.Err()) + + // Assert + require.Equal(expectStates, actualStates) + s.logger.Debug("successfully asserted all connection states on controller", "expected_states", expectStates, "actual_states", actualStates) } // ExpectConnectionStateOnWorker waits until all connections in a @@ -491,7 +741,6 @@ type testSessionConnection struct { func (s *testSession) Connect( ctx context.Context, t *testing.T, // Just to add cleanup - logger hclog.Logger, ) *testSessionConnection { t.Helper() require := require.New(t) @@ -504,7 +753,7 @@ func (s *testSession) Connect( return &testSessionConnection{ conn: conn, - logger: logger, + logger: s.logger, } } @@ -518,6 +767,12 @@ func (s *testSession) Connect( func (c *testSessionConnection) testSendRecv(t *testing.T) bool { t.Helper() require := require.New(t) + + // This is a fairly arbitrary value, as the send/recv is + // instantaneous. The main key here is just to make sure that we do + // it a reasonable amount of times to know the connection is + // stable. + const testSendRecvSendMax = 100 for i := uint32(0); i < testSendRecvSendMax; i++ { // Shuttle over the sequence number as base64. err := binary.Write(c.conn, binary.LittleEndian, i) @@ -536,7 +791,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool { var j uint32 err = binary.Read(c.conn, binary.LittleEndian, &j) if err != nil { - c.logger.Debug("received error during read", "err", err) + c.logger.Debug("received error during read", "err", err, "num_successfully_sent", i) if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) { @@ -549,9 +804,10 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool { require.Equal(j, i) // Sleep 1s - time.Sleep(time.Second) + // time.Sleep(time.Second) } + c.logger.Debug("finished send/recv successfully", "num_successfully_sent", testSendRecvSendMax) return true } @@ -560,6 +816,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool { func (c *testSessionConnection) TestSendRecvAll(t *testing.T) { t.Helper() require.True(t, c.testSendRecv(t)) + c.logger.Debug("successfully asserted send/recv as passing") } // TestSendRecvFail asserts that we were able to send/recv all pings @@ -567,6 +824,7 @@ func (c *testSessionConnection) TestSendRecvAll(t *testing.T) { func (c *testSessionConnection) TestSendRecvFail(t *testing.T) { t.Helper() require.False(t, c.testSendRecv(t)) + c.logger.Debug("successfully asserted send/recv as failing") } type testTcpServer struct {