diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go
index 3e376c0ee9..0560d8a57f 100644
--- a/internal/cmd/base/servers.go
+++ b/internal/cmd/base/servers.go
@@ -15,6 +15,7 @@ import (
 	"strings"
 	"sync"
 	"syscall"
+	"time"

 	"github.com/armon/go-metrics"
 	berrors "github.com/hashicorp/boundary/internal/errors"
@@ -24,6 +25,7 @@ import (
 	"github.com/hashicorp/boundary/internal/db"
 	"github.com/hashicorp/boundary/internal/kms"
 	"github.com/hashicorp/boundary/internal/observability/event"
+	"github.com/hashicorp/boundary/internal/servers"
 	"github.com/hashicorp/boundary/internal/types/scope"
 	"github.com/hashicorp/boundary/version"
 	"github.com/hashicorp/go-hclog"
@@ -40,6 +42,23 @@ import (
 	"google.golang.org/grpc/grpclog"
 )

+const (
+	// defaultStatusGracePeriod is the default status grace period, or the period
+	// of time that we will go without a status report before we start
+	// disconnecting and marking connections as closed. This is tied to the
+	// default server liveness setting (servers.DefaultLiveness); see the
+	// servers package for more details.
+	defaultStatusGracePeriod = servers.DefaultLiveness
+
+	// statusGracePeriodEnvVar is the environment variable that can be used to
+	// configure the status grace period. This setting is provided in seconds,
+	// and can never be lower than the default status grace period defined above.
+	//
+	// TODO: This value is temporary; it will be removed once we have a better
+	// story/direction on attributes and system defaults.
+	statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD"
+)
+
 type Server struct {
 	*Command

@@ -101,6 +120,11 @@ type Server struct {
 	DevDatabaseCleanupFunc func() error

 	Database *gorm.DB
+
+	// StatusGracePeriodDuration represents the period of time (as a
+	// duration) that the controller will wait before marking
+	// connections from a disconnected worker as invalid.
+	StatusGracePeriodDuration time.Duration
 }

 func NewServer(cmd *Command) *Server {
@@ -675,3 +699,51 @@ func MakeSighupCh() chan struct{} {
 	}()
 	return resultCh
 }
+
+// SetStatusGracePeriodDuration sets the value for
+// StatusGracePeriodDuration.
+//
+// The grace period is the length of time we allow connections to run
+// on a worker in the event of an error sending status updates. The
+// period is defined as the length of time since the last successful
+// update.
+//
+// The setting is derived from one of the following, in order:
+//
+// * The supplied value, if non-zero.
+// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an
+//   integer value to define the setting.
+// * If neither of these is set, the default is used. See the
+//   defaultStatusGracePeriod value for the default value.
+//
+// The minimum setting for this value is the default setting. Values
+// below this will be reset to the default.
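+//
+// As a sketch (values illustrative), resolution works as follows for a
+// server s with BOUNDARY_STATUS_GRACE_PERIOD unset:
+//
+//	s.SetStatusGracePeriodDuration(0)               // no value: falls back to the default
+//	s.SetStatusGracePeriodDuration(time.Minute)     // explicit value above the minimum: used as-is
+//	s.SetStatusGracePeriodDuration(5 * time.Second) // below the minimum: reset to the default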
+func (s *Server) SetStatusGracePeriodDuration(value time.Duration) {
+	var result time.Duration
+	switch {
+	case value > 0:
+		result = value
+	case os.Getenv(statusGracePeriodEnvVar) != "":
+		// TODO: See the description of the constant for more details on
+		// this env var
+		v := os.Getenv(statusGracePeriodEnvVar)
+		n, err := strconv.Atoi(v)
+		if err != nil {
+			s.Logger.Error(fmt.Sprintf("could not read setting for %s", statusGracePeriodEnvVar),
+				"err", err,
+				"value", v,
+			)
+			break
+		}
+
+		result = time.Second * time.Duration(n)
+	}
+
+	if result < defaultStatusGracePeriod {
+		s.Logger.Debug("invalid grace period setting or none provided, using default", "value", result, "default", defaultStatusGracePeriod)
+		result = defaultStatusGracePeriod
+	}
+
+	s.Logger.Debug("session cleanup in effect, connections will be terminated if status reports cannot be made", "grace_period", result)
+	s.StatusGracePeriodDuration = result
+}
diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go
index eaecebe511..742e28b7ee 100644
--- a/internal/cmd/commands/dev/dev.go
+++ b/internal/cmd/commands/dev/dev.go
@@ -425,6 +425,10 @@ func (c *Command) Run(args []string) int {
 		return base.CommandCliError
 	}

+	// Initialize status grace period (0 denotes using env or default
+	// here)
+	c.SetStatusGracePeriodDuration(0)
+
 	base.StartMemProfiler(c.Logger)

 	if err := c.SetupMetrics(c.UI, c.Config.Telemetry); err != nil {
diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go
index f663a609fc..ef4a498f68 100644
--- a/internal/cmd/commands/server/server.go
+++ b/internal/cmd/commands/server/server.go
@@ -156,6 +156,10 @@ func (c *Command) Run(args []string) int {
 		return base.CommandUserError
 	}

+	// Initialize status grace period (0 denotes using env or default
+	// here)
+	c.SetStatusGracePeriodDuration(0)
+
 	base.StartMemProfiler(c.Logger)

 	if !c.skipMetrics {
diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go
index b51bbb7558..551cdb29d4 100644
--- a/internal/cmd/config/config.go
+++ b/internal/cmd/config/config.go
@@ -119,6 +119,13 @@ type Controller struct {
 	// denoted by time.Duration
 	AuthTokenTimeToStale         interface{}   `hcl:"auth_token_time_to_stale"`
 	AuthTokenTimeToStaleDuration time.Duration
+
+	// StatusGracePeriodDuration represents the period of time (as a duration)
+	// that the controller will wait before marking connections from a
+	// disconnected worker as invalid.
+	//
+	// TODO: This field is currently internal.
+	StatusGracePeriodDuration time.Duration `hcl:"-"`
 }

 type Worker struct {
@@ -132,6 +139,13 @@ type Worker struct {
 	// key=value syntax. This is trued up in the Parse function below.
 	TagsRaw interface{}         `hcl:"tags"`
 	Tags    map[string][]string `hcl:"-"`
+
+	// StatusGracePeriodDuration represents the period of time (as a duration)
+	// that the worker will wait before disconnecting connections if it cannot
+	// make a status report to a controller.
+	//
+	// TODO: This field is currently internal.
+	StatusGracePeriodDuration time.Duration `hcl:"-"`
 }

 type Database struct {
diff --git a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql
new file mode 100644
index 0000000000..09b2468802
--- /dev/null
+++ b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs.up.sql
@@ -0,0 +1,23 @@
+begin;
+
+  create function wt_sub_seconds(sec integer, ts timestamp with time zone)
+    returns timestamp with time zone
+    as $$
+    select ts - sec * '1 second'::interval;
+    $$ language sql
+    stable
+    returns null on null input;
+  comment on function wt_sub_seconds is
+    'wt_sub_seconds returns ts - sec.';
+
+  create function wt_sub_seconds_from_now(sec integer)
+    returns timestamp with time zone
+    as $$
+    select wt_sub_seconds(sec, current_timestamp);
+    $$ language sql
+    stable
+    returns null on null input;
+  comment on function wt_sub_seconds_from_now is
+    'wt_sub_seconds_from_now returns current_timestamp - sec.';
+
+commit;
diff --git a/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go
new file mode 100644
index 0000000000..3eda4f90f0
--- /dev/null
+++ b/internal/db/schema/migrations/postgres/12/01_timestamp_sub_funcs_test.go
@@ -0,0 +1,61 @@
+package migration
+
+import (
+	"context"
+	"database/sql"
+	"testing"
+	"time"
+
+	"github.com/hashicorp/boundary/internal/db/schema"
+	"github.com/hashicorp/boundary/internal/docker"
+	"github.com/stretchr/testify/require"
+)
+
+const targetMigration = 12001
+
+func TestWtSubSeconds(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	ctx := context.Background()
+	d := testSetupDb(ctx, t)
+
+	// Test by subtracting a day from the test date
+	sourceTime, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05+07:00")
+	require.NoError(err)
+	expectedTime := sourceTime.Add(time.Second * -86400)
+
+	var actualTime time.Time
+	row := d.QueryRowContext(ctx, "select wt_sub_seconds($1, $2)", 86400, sourceTime)
+	require.NoError(row.Scan(&actualTime))
+	require.True(expectedTime.Equal(actualTime))
+}
+
+func testSetupDb(ctx context.Context, t *testing.T) *sql.DB {
+	t.Helper()
+	require := require.New(t)
+
+	dialect := "postgres"
+
+	c, u, _, err := docker.StartDbInDocker(dialect)
+	require.NoError(err)
+	t.Cleanup(func() {
+		require.NoError(c())
+	})
+	d, err := sql.Open(dialect, u)
+	require.NoError(err)
+
+	oState := schema.TestCloneMigrationStates(t)
+	nState := schema.TestCreatePartialMigrationState(oState["postgres"], targetMigration)
+	oState["postgres"] = nState
+
+	m, err := schema.NewManager(ctx, dialect, d, schema.WithMigrationStates(oState))
+	require.NoError(err)
+
+	require.NoError(m.RollForward(ctx))
+	state, err := m.CurrentState(ctx)
+	require.NoError(err)
+	require.Equal(targetMigration, state.DatabaseSchemaVersion)
+	require.False(state.Dirty)
+
+	return d
+}
diff --git a/internal/db/schema/postgres_migration.gen.go b/internal/db/schema/postgres_migration.gen.go
index d31e3d11ae..8b501a12c5 100644
--- a/internal/db/schema/postgres_migration.gen.go
+++ b/internal/db/schema/postgres_migration.gen.go
@@ -4,7 +4,7 @@ package schema
 func init() {
 	migrationStates["postgres"] = migrationState{
-		binarySchemaVersion: 11001,
+		binarySchemaVersion: 12001,
 		upMigrations: map[int][]byte{
 			1: []byte(`
 create domain wt_public_id as text
@@ -6100,6 +6100,27 @@ alter table server
 	foreign key (type)
 	references server_type_enm(name)
 	on update cascade
 	on delete restrict;
+`),
+			12001: []byte(`
+create function wt_sub_seconds(sec integer, ts timestamp with time zone)
+    returns timestamp with time zone
+    as $$
+    select ts - sec * '1 second'::interval;
+    $$ language sql
+    stable
+    returns null on null input;
+  comment on function wt_sub_seconds is
+    'wt_sub_seconds returns ts - sec.';
+
+  create function wt_sub_seconds_from_now(sec integer)
+    returns timestamp with time zone
+    as $$
+    select wt_sub_seconds(sec, current_timestamp);
+    $$ language sql
+    stable
+    returns null on null input;
+  comment on function wt_sub_seconds_from_now is
+    'wt_sub_seconds_from_now returns current_timestamp - sec.';
 `),
 			2001: []byte(`
 -- log_migration entries represent logs generated during migrations
diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go
index f4366ba004..a9029547b4 100644
--- a/internal/servers/controller/controller.go
+++ b/internal/servers/controller/controller.go
@@ -38,7 +38,7 @@ type Controller struct {
 	workerAuthCache *cache.Cache

-	// Used for testing
+	// Used for testing and tracking worker health
 	workerStatusUpdateTimes *sync.Map

 	// Repo factory methods
@@ -117,7 +117,9 @@ func New(conf *Config) (*Controller, error) {
 	jobRepoFn := func() (*job.Repository, error) {
 		return job.NewRepository(dbase, dbase, c.kms)
 	}
-	c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger)
+	// TODO: the RunJobsLimit is temporary until a better fix gets in. This
+	// currently caps the scheduler at running 10 jobs per interval.
+	c.scheduler, err = scheduler.New(c.conf.RawConfig.Controller.Name, jobRepoFn, c.logger, scheduler.WithRunJobsLimit(10))
 	if err != nil {
 		return nil, fmt.Errorf("error creating new scheduler: %w", err)
 	}
@@ -182,7 +184,29 @@ func (c *Controller) Start() error {

 func (c *Controller) registerJobs() error {
 	rw := db.New(c.conf.Database)
-	return vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger)
+	if err := vault.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, c.logger); err != nil {
+		return err
+	}
+
+	if err := c.registerSessionCleanupJob(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// registerSessionCleanupJob is a helper method to abstract
+// registering the session cleanup job specifically.
+func (c *Controller) registerSessionCleanupJob() error {
+	sessionCleanupJob, err := newSessionCleanupJob(c.logger, c.SessionRepoFn, int(c.conf.StatusGracePeriodDuration.Seconds()))
+	if err != nil {
+		return fmt.Errorf("error creating session cleanup job: %w", err)
+	}
+	if err = c.scheduler.RegisterJob(c.baseContext, sessionCleanupJob); err != nil {
+		return fmt.Errorf("error registering session cleanup job: %w", err)
+	}
+
+	return nil
 }

 func (c *Controller) Shutdown(serversOnly bool) error {
diff --git a/internal/servers/controller/handlers/workers/worker_service.go b/internal/servers/controller/handlers/workers/worker_service.go
index 4135ccf7b0..56520a03f8 100644
--- a/internal/servers/controller/handlers/workers/worker_service.go
+++ b/internal/servers/controller/handlers/workers/worker_service.go
@@ -76,12 +76,19 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 		Controllers: controllers,
 	}

-	// TODO (jeff, 05-2021): We should possibly list all connections here with
-	// their statuses, and check those against the found connections below. If
-	// any we think should be closed are marked open by the worker, the worker
-	// should be told to close them.
-
-	var foundConns []string
+	var (
+		// For tracking the reported open connections.
+		reportedOpenConns []string
+		// For tracking the session IDs we've already requested
+		// cancellation for. We won't need to add connection cancel
+		// requests for these because canceling the session terminates the
+		// connections.
+		requestedSessionCancelIds []string
+	)
+
+	// This is a map of all sessions and their statuses. We keep track of
+	// this for easy lookup if we need to make change requests.
+	sessionStatuses := make(map[string]pbs.SESSIONSTATUS)

 	for _, jobStatus := range req.GetJobs() {
 		switch jobStatus.Job.GetType() {
@@ -92,6 +99,9 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 				return nil, status.Error(codes.Internal, "Error getting session info at status time")
 			}

+			// Record status.
+			sessionStatuses[si.GetSessionId()] = si.Status
+
 			// Check connections before potentially bypassing the rest of the
 			// logic in the switch on si.Status.
 			sessConns := si.GetConnections()
@@ -103,7 +113,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 					// report as found, so that we should attempt to close it.
 					// Note that unspecified is the default state for the enum
 					// but it's not ever explicitly set by us.
-					foundConns = append(foundConns, conn.GetConnectionId())
+					reportedOpenConns = append(reportedOpenConns, conn.GetConnectionId())
 				}
 			}

@@ -127,7 +137,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 				}
 				// If the session from the DB is in canceling status, and we're
 				// here, it means the job is in pending or active; cancel it. If
-				// it's in termianted status something went wrong and we're
+				// it's in terminated status something went wrong and we're
 				// mismatched, so ensure we cancel it also.
 				currState := sessionInfo.States[0].Status
 				if currState.ProtoVal() != si.Status {
@@ -148,13 +158,53 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 						},
 						RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
 					})
+					// Log the session ID so we don't add a duplicate change
+					// request on connection normalization.
+					requestedSessionCancelIds = append(requestedSessionCancelIds, sessionId)
 				}
 			}
 		}
 	}

-	// Run our connection cleanup function
-	closedConns, err := sessRepo.CloseDeadConnectionsOnWorkerReport(ctx, req.Worker.PrivateId, foundConns)
+	// Normalize the current state of connections on the worker side
+	// with the data from the controller. In other words, if one of our
+	// found connections isn't supposed to be alive still, kill it.
+	//
+	// This is separate from the above session normalization and is
+	// additive to it; we don't add sessions that were already handled
+	// there, since canceling a session already closes its connections.
+	shouldCloseConnections, err := sessRepo.ShouldCloseConnectionsOnWorker(ctx, reportedOpenConns, requestedSessionCancelIds)
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "Error fetching connections that should be closed: %v", err)
+	}
+
+	for sessionId, connIds := range shouldCloseConnections {
+		var connChanges []*pbs.Connection
+		for _, connId := range connIds {
+			connChanges = append(connChanges, &pbs.Connection{
+				ConnectionId: connId,
+				Status:       session.StatusClosed.ProtoVal(),
+			})
+		}
+
+		ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{
+			Job: &pbs.Job{
+				Type: pbs.JOBTYPE_JOBTYPE_SESSION,
+				JobInfo: &pbs.Job_SessionInfo{
+					SessionInfo: &pbs.SessionJobInfo{
+						SessionId:   sessionId,
+						Status:      sessionStatuses[sessionId],
+						Connections: connChanges,
+					},
+				},
+			},
+			RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
+		})
+	}
+
+	// Run our controller-side cleanup function.
+	closedConns, err := sessRepo.CloseDeadConnectionsForWorker(ctx, req.Worker.PrivateId, reportedOpenConns)
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "Error closing dead conns for worker %s: %v", req.Worker.PrivateId, err)
 	}
diff --git a/internal/servers/controller/session_cleanup_job.go b/internal/servers/controller/session_cleanup_job.go
new file mode 100644
index 0000000000..fd59e6fa8d
--- /dev/null
+++ b/internal/servers/controller/session_cleanup_job.go
@@ -0,0 +1,133 @@
+package controller
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/hashicorp/boundary/internal/errors"
+	"github.com/hashicorp/boundary/internal/scheduler"
+	"github.com/hashicorp/boundary/internal/servers/controller/common"
+	"github.com/hashicorp/boundary/internal/session"
+	"github.com/hashicorp/go-hclog"
+)
+
+// sessionCleanupJob defines a periodic job that monitors workers for
+// loss of connection and terminates connections on workers that have
+// not sent a heartbeat in a significant period of time.
+//
+// Relevant connections are simply marked as disconnected in the
+// database. Connections will be independently terminated by the
+// worker; in the event of a synchronization issue between the two,
+// the controller will win out and order that the connections be
+// closed on the worker.
+type sessionCleanupJob struct {
+	logger        hclog.Logger
+	sessionRepoFn common.SessionRepoFactory
+
+	// The amount of time to give disconnected workers before marking
+	// their connections as closed.
+	gracePeriod int
+
+	// The total number of connections closed in the last run.
+	totalClosed int
+}
+
+// newSessionCleanupJob instantiates the session cleanup job.
+func newSessionCleanupJob(
+	logger hclog.Logger,
+	sessionRepoFn common.SessionRepoFactory,
+	gracePeriod int,
+) (*sessionCleanupJob, error) {
+	const op = "controller.newSessionCleanupJob"
+	switch {
+	case logger == nil:
+		return nil, errors.New(errors.InvalidParameter, op, "missing logger")
+	case sessionRepoFn == nil:
+		return nil, errors.New(errors.InvalidParameter, op, "missing sessionRepoFn")
+	case gracePeriod < session.DeadWorkerConnCloseMinGrace:
+		return nil, errors.New(
+			errors.InvalidParameter, op, fmt.Sprintf("invalid gracePeriod, must be at least %d", session.DeadWorkerConnCloseMinGrace))
+	}
+
+	return &sessionCleanupJob{
+		logger:        logger,
+		sessionRepoFn: sessionRepoFn,
+		gracePeriod:   gracePeriod,
+	}, nil
+}
+
+// Name returns a short, unique name for the job.
+func (j *sessionCleanupJob) Name() string { return "session_cleanup" }
+
+// Description returns the description for the job.
+func (j *sessionCleanupJob) Description() string { + return "Clean up session connections from disconnected workers" +} + +// NextRunIn returns the next run time after a job is completed. +// +// The next run time is defined for sessionCleanupJob as one second. +// This is because the job should run continuously to terminate +// connections as soon as a worker has not reported in for a long +// enough time. Only one job will ever run at once, so there is no +// reason why it cannot run again immediately. +func (j *sessionCleanupJob) NextRunIn() (time.Duration, error) { return time.Second, nil } + +// Status returns the status of the running job. +func (j *sessionCleanupJob) Status() scheduler.JobStatus { + return scheduler.JobStatus{ + Completed: j.totalClosed, + Total: j.totalClosed, + } +} + +// Run executes the job. +func (j *sessionCleanupJob) Run(ctx context.Context) error { + const op = "controller.(sessionCleanupJob).Run" + j.logger.Debug( + "starting job", + "op", op, + ) + j.totalClosed = 0 + + // Load repos. + sessionRepo, err := j.sessionRepoFn() + if err != nil { + return errors.Wrap(err, op, errors.WithMsg("error getting session repo")) + } + + // Run the atomic dead worker cleanup job. + results, err := sessionRepo.CloseConnectionsForDeadWorkers(ctx, j.gracePeriod) + if err != nil { + return errors.Wrap(err, op) + } + + if len(results) < 1 { + j.logger.Debug( + "all workers OK, no connections to close", + "op", op, + ) + } else { + for _, result := range results { + // Log a closed connection message for each worker as a warning + j.logger.Warn( + "worker has not reported status within acceptable grace period, all connections closed", + "op", op, + "private_id", result.ServerId, + "update_time", result.LastUpdateTime, + "grace_period_seconds", j.gracePeriod, + "number_connections_closed", result.NumberConnectionsClosed, + ) + + j.totalClosed += result.NumberConnectionsClosed + } + } + + j.logger.Debug( + "job finished", + "op", op, + "total_connections_closed", j.totalClosed, + ) + return nil +} diff --git a/internal/servers/controller/session_cleanup_job_test.go b/internal/servers/controller/session_cleanup_job_test.go new file mode 100644 index 0000000000..b83dd67534 --- /dev/null +++ b/internal/servers/controller/session_cleanup_job_test.go @@ -0,0 +1,170 @@ +package controller + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/scheduler" + "github.com/hashicorp/boundary/internal/servers" + "github.com/hashicorp/boundary/internal/session" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// assert the interface +var _ = scheduler.Job(new(sessionCleanupJob)) + +// This test has been largely adapted from +// TestRepository_CloseDeadConnectionsOnWorker in +// internal/session/repository_connection_test.go. 
+func TestSessionCleanupJob(t *testing.T) { + t.Parallel() + require, assert := require.New(t), assert.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + sessionRepo, err := session.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + numConns := 12 + + // Create two "workers". One will remain untouched while the other "goes + // away and comes back" (worker 2). + worker1 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker1")) + worker2 := session.TestWorker(t, conn, wrapper, session.WithServerId("worker2")) + + // Create a few sessions on each, activate, and authorize a connection + var connIds []string + connIdsByWorker := make(map[string][]string) + for i := 0; i < numConns; i++ { + serverId := worker1.PrivateId + if i%2 == 0 { + serverId = worker2.PrivateId + } + sess := session.TestDefaultSession(t, conn, wrapper, iamRepo, session.WithServerId(serverId), session.WithDbOpts(db.WithSkipVetForWrite(true))) + sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) + require.NoError(err) + c, cs, _, err := sessionRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + require.NoError(err) + require.Len(cs, 1) + require.Equal(session.StatusAuthorized, cs[0].Status) + connIds = append(connIds, c.GetPublicId()) + if i%2 == 0 { + connIdsByWorker[worker2.PrivateId] = append(connIdsByWorker[worker2.PrivateId], c.GetPublicId()) + } else { + connIdsByWorker[worker1.PrivateId] = append(connIdsByWorker[worker1.PrivateId], c.GetPublicId()) + } + } + + // Mark half of the connections connected and leave the others authorized. + // This is just to ensure we have a spread when we test it out. + for i, connId := range connIds { + if i%2 == 0 { + _, cs, err := sessionRepo.ConnectConnection(ctx, session.ConnectWith{ + ConnectionId: connId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 22, + }) + require.NoError(err) + require.Len(cs, 2) + var foundAuthorized, foundConnected bool + for _, status := range cs { + if status.Status == session.StatusAuthorized { + foundAuthorized = true + } + if status.Status == session.StatusConnected { + foundConnected = true + } + } + require.True(foundAuthorized) + require.True(foundConnected) + } + } + + // Create the job. + job, err := newSessionCleanupJob( + hclog.New(&hclog.LoggerOptions{Level: hclog.Trace}), + func() (*session.Repository, error) { return sessionRepo, nil }, + session.DeadWorkerConnCloseMinGrace, + ) + require.NoError(err) + + // sleep the status grace period. + time.Sleep(time.Second * time.Duration(session.DeadWorkerConnCloseMinGrace)) + + // Push an upsert to the first worker so that its status has been + // updated. + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, worker1, []servers.Option{}...) + require.NoError(err) + require.Equal(1, rowsUpdated) + + // Run the job. + require.NoError(job.Run(ctx)) + + // Assert connection state on both workers. 
+	assertConnections := func(workerId string, closed bool) {
+		connIds, ok := connIdsByWorker[workerId]
+		require.True(ok)
+		require.Len(connIds, 6)
+		for _, connId := range connIds {
+			_, states, err := sessionRepo.LookupConnection(ctx, connId, nil)
+			require.NoError(err)
+			var foundClosed bool
+			for _, state := range states {
+				if state.Status == session.StatusClosed {
+					foundClosed = true
+					break
+				}
+			}
+			assert.Equal(closed, foundClosed)
+		}
+	}
+
+	// Assert that all connections on the second worker are closed
+	assertConnections(worker2.PrivateId, true)
+	// Assert that all connections on the first worker are still open
+	assertConnections(worker1.PrivateId, false)
+}
+
+func TestSessionCleanupJobNewJobErr(t *testing.T) {
+	t.Parallel()
+	const op = "controller.newSessionCleanupJob"
+	require := require.New(t)
+
+	job, err := newSessionCleanupJob(nil, nil, 0)
+	require.Equal(err, errors.E(
+		errors.WithCode(errors.InvalidParameter),
+		errors.WithOp(op),
+		errors.WithMsg("missing logger"),
+	))
+	require.Nil(job)
+
+	job, err = newSessionCleanupJob(hclog.New(nil), nil, 0)
+	require.Equal(err, errors.E(
+		errors.WithCode(errors.InvalidParameter),
+		errors.WithOp(op),
+		errors.WithMsg("missing sessionRepoFn"),
+	))
+	require.Nil(job)
+
+	job, err = newSessionCleanupJob(hclog.New(nil), func() (*session.Repository, error) { return nil, nil }, 0)
+	require.Equal(err, errors.E(
+		errors.WithCode(errors.InvalidParameter),
+		errors.WithOp(op),
+		errors.WithMsg(fmt.Sprintf("invalid gracePeriod, must be at least %d", session.DeadWorkerConnCloseMinGrace)),
+	))
+	require.Nil(job)
+}
diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go
index 620dccfae5..67be160392 100644
--- a/internal/servers/controller/testing.go
+++ b/internal/servers/controller/testing.go
@@ -8,6 +8,7 @@ import (
 	"strconv"
 	"strings"
 	"testing"
+	"time"

 	"github.com/hashicorp/boundary/api"
 	"github.com/hashicorp/boundary/api/authmethods"
@@ -371,6 +372,10 @@ type TestControllerOpts struct {
 	// A cluster address for overriding the advertised controller listener
 	// (overrides address provided in config, if any)
 	PublicClusterAddr string
+
+	// The amount of time to wait before marking connections as closed when a
+	// worker has not reported in
+	StatusGracePeriodDuration time.Duration
 }

 func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
@@ -456,6 +461,9 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
 		t.Fatal(err)
 	}

+	// Initialize status grace period
+	tc.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration)
+
 	if opts.Config.Controller == nil {
 		opts.Config.Controller = new(config.Controller)
 	}
@@ -615,6 +623,7 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
 		DisableKmsKeyCreation:     true,
 		DisableAuthMethodCreation: true,
 		PublicClusterAddr:         opts.PublicClusterAddr,
+		StatusGracePeriodDuration: opts.StatusGracePeriodDuration,
 	}
 	if opts.Logger != nil {
 		nextOpts.Logger = opts.Logger
@@ -629,3 +638,71 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
 	}
 	return NewTestController(t, nextOpts)
 }
+
+// WaitForNextWorkerStatusUpdate waits for the next status report from a
+// worker to come in. If it does not come in within the configured status
+// grace period, this function returns an error.
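+//
+// For example, in a test (the worker ID here is hypothetical):
+//
+//	if err := tc.WaitForNextWorkerStatusUpdate("test-worker"); err != nil {
+//		t.Fatalf("worker did not report within the grace period: %v", err)
+//	}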
+func (tc *TestController) WaitForNextWorkerStatusUpdate(workerId string) error {
+	tc.Logger().Debug("waiting for next status report from worker", "worker", workerId)
+
+	if err := tc.waitForNextWorkerStatusUpdate(workerId); err != nil {
+		tc.Logger().Error("error waiting for next status report from worker", "worker", workerId, "err", err)
+		return err
+	}
+
+	tc.Logger().Debug("next status report from worker received successfully", "worker", workerId)
+	return nil
+}
+
+func (tc *TestController) waitForNextWorkerStatusUpdate(workerId string) error {
+	waitStatusStart := time.Now()
+	ctx, cancel := context.WithTimeout(tc.ctx, tc.b.StatusGracePeriodDuration)
+	defer cancel()
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+
+		case <-time.After(time.Second):
+			// pass
+		}
+
+		var waitStatusCurrent time.Time
+		var err error
+		tc.Controller().WorkerStatusUpdateTimes().Range(func(k, v interface{}) bool {
+			if k == nil || v == nil {
+				err = fmt.Errorf("nil key or value on entry: key=%#v value=%#v", k, v)
+				return false
+			}
+
+			workerStatusUpdateId, ok := k.(string)
+			if !ok {
+				err = fmt.Errorf("unexpected type %T for key: key=%#v value=%#v", k, k, v)
+				return false
+			}
+
+			workerStatusUpdateTime, ok := v.(time.Time)
+			if !ok {
+				err = fmt.Errorf("unexpected type %T for value: key=%#v value=%#v", k, k, v)
+				return false
+			}
+
+			if workerStatusUpdateId == workerId {
+				waitStatusCurrent = workerStatusUpdateTime
+				return false
+			}
+
+			return true
+		})
+
+		if err != nil {
+			return err
+		}
+
+		if waitStatusCurrent.Sub(waitStatusStart) > 0 {
+			break
+		}
+	}
+
+	return nil
+}
diff --git a/internal/servers/options.go b/internal/servers/options.go
index 6c2b0cf510..f82f5a52d2 100644
--- a/internal/servers/options.go
+++ b/internal/servers/options.go
@@ -37,8 +37,9 @@ func WithLimit(limit int) Option {
 	}
 }

-// WithSkipVetForWrite provides an option to allow skipping vet checks to allow
-// testing lower-level SQL triggers and constraints
+// WithLiveness indicates how far back we want to search for server entries.
+// Use 0 for the default liveness (15 seconds). A liveness value of -1 removes
+// the liveness condition.
 func WithLiveness(liveness time.Duration) Option {
 	return func(o *options) {
 		o.withLiveness = liveness
diff --git a/internal/servers/repository.go b/internal/servers/repository.go
index 53263c566d..c25a2a7592 100644
--- a/internal/servers/repository.go
+++ b/internal/servers/repository.go
@@ -13,7 +13,15 @@ import (
 )

 const (
-	defaultLiveness = 15 * time.Second
+	// DefaultLiveness is the default server "liveness" setting: the maximum
+	// amount of time a worker may go without sending a status update to the
+	// controller. After this, the server is considered dead, is taken out of
+	// the rotation of workers eligible for connections, and its connections
+	// may start to be terminated and marked as closed, depending on the grace
+	// period setting (see base.Server.StatusGracePeriodDuration). This value
+	// also serves as the default and minimum allowable setting for the grace
+	// period.
+	DefaultLiveness = 15 * time.Second
 )

 type ServerType string
@@ -67,18 +75,27 @@ func (r *Repository) listServersWithReader(ctx context.Context, reader db.Reader
 	opts := getOpts(opt...)
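+	// A liveness of 0 selects DefaultLiveness; a negative liveness disables
+	// the update_time filter entirely (see WithLiveness).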
liveness := opts.withLiveness if liveness == 0 { - liveness = defaultLiveness + liveness = DefaultLiveness } + + var where string + if liveness > 0 { + where = fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())) + } else { + where = "type = $1" + } + var servers []*Server if err := reader.SearchWhere( ctx, &servers, - fmt.Sprintf("type = $1 and update_time > now() - interval '%d seconds'", uint32(liveness.Seconds())), + where, []interface{}{serverType}, db.WithLimit(-1), ); err != nil { return nil, errors.Wrap(err, "servers.listServersWithReader") } + return servers, nil } diff --git a/internal/servers/repository_test.go b/internal/servers/repository_test.go index f58736f0c3..b36f6d0ef3 100644 --- a/internal/servers/repository_test.go +++ b/internal/servers/repository_test.go @@ -1,12 +1,14 @@ package servers_test import ( + "context" "testing" "time" "github.com/hashicorp/boundary/api/roles" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/servers/controller" "github.com/hashicorp/boundary/internal/types/scope" @@ -180,3 +182,80 @@ func TestTagUpdatingListing(t *testing.T) { } require.Equal(exp, tags) } + +func TestListServersWithLiveness(t *testing.T) { + t.Parallel() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + + newServer := func(privateId string) *servers.Server { + result := &servers.Server{ + PrivateId: privateId, + Type: "worker", + Address: "127.0.0.1", + } + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, result) + require.NoError(err) + require.Equal(1, rowsUpdated) + + return result + } + + server1 := newServer("test1") + server2 := newServer("test2") + server3 := newServer("test3") + + // Sleep the default liveness time (15sec currently) +1s + time.Sleep(time.Second * 16) + + // Push an upsert to the first worker so that its status has been + // updated. + _, rowsUpdated, err := serversRepo.UpsertServer(ctx, server1) + require.NoError(err) + require.Equal(1, rowsUpdated) + + requireIds := func(expected []string, actual []*servers.Server) { + require.Len(expected, len(actual)) + want := make(map[string]struct{}) + for _, v := range expected { + want[v] = struct{}{} + } + + got := make(map[string]struct{}) + for _, v := range actual { + got[v.PrivateId] = struct{}{} + } + + require.Equal(want, got) + } + + // Default liveness, should only list 1 + result, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker) + require.NoError(err) + require.Len(result, 1) + requireIds([]string{server1.PrivateId}, result) + + // Upsert second server. + _, rowsUpdated, err = serversRepo.UpsertServer(ctx, server2) + require.NoError(err) + require.Equal(1, rowsUpdated) + + // Static liveness. Should get two, so long as this did not take + // more than 5s to execute. + result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(time.Second*5)) + require.NoError(err) + require.Len(result, 2) + requireIds([]string{server1.PrivateId, server2.PrivateId}, result) + + // Liveness disabled, should get all three servers. 
+ result, err = serversRepo.ListServers(ctx, servers.ServerTypeWorker, servers.WithLiveness(-1)) + require.NoError(err) + require.Len(result, 3) + requireIds([]string{server1.PrivateId, server2.PrivateId, server3.PrivateId}, result) +} diff --git a/internal/servers/worker/status.go b/internal/servers/worker/status.go index 821ceac7ba..6a083b699c 100644 --- a/internal/servers/worker/status.go +++ b/internal/servers/worker/status.go @@ -3,8 +3,6 @@ package worker import ( "context" "math/rand" - "os" - "strconv" "time" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -15,48 +13,10 @@ import ( // In the future we could make this configurable const ( - statusInterval = 2 * time.Second - statusTimeout = 5 * time.Second - defaultStatusGracePeriod = 30 * time.Second - statusGracePeriodEnvVar = "BOUNDARY_STATUS_GRACE_PERIOD" + statusInterval = 2 * time.Second + statusTimeout = 5 * time.Second ) -// statusGracePeriod returns the status grace period setting for this -// worker, in seconds. -// -// The grace period is the length of time we allow connections to run -// on a worker in the event of an error sending status updates. The -// period is defined the length of time since the last successful -// update. -// -// The setting is derived from one of the following: -// -// * BOUNDARY_STATUS_GRACE_PERIOD, if defined, can be set to an -// integer value to define the setting. -// * If this is missing, the default (30 seconds) is used. -// -func (w *Worker) statusGracePeriod() time.Duration { - if v := os.Getenv(statusGracePeriodEnvVar); v != "" { - n, err := strconv.Atoi(v) - if err != nil { - w.logger.Error("could not read setting for BOUNDARY_STATUS_GRACE_PERIOD, using default", - "err", err, - "value", v, - ) - return defaultStatusGracePeriod - } - - if n < 1 { - w.logger.Error("invalid setting for BOUNDARY_STATUS_GRACE_PERIOD, using default", "value", v) - return defaultStatusGracePeriod - } - - return time.Second * time.Duration(n) - } - - return defaultStatusGracePeriod -} - type LastStatusInformation struct { *pbs.StatusResponse StatusTime time.Time @@ -99,6 +59,37 @@ func (w *Worker) LastStatusSuccess() *LastStatusInformation { return w.lastStatusSuccess.Load().(*LastStatusInformation) } +// WaitForNextSuccessfulStatusUpdate waits for the next successful status. It's +// used by testing (and in the future, shutdown) in place of a more opaque and +// possibly unnecessarily long sleep for things like initial controller +// check-in, etc. +// +// The timeout is aligned with the worker's status grace period. A nil error +// means the status was sent successfully. +func (w *Worker) WaitForNextSuccessfulStatusUpdate() error { + w.logger.Debug("waiting for next status report to controller") + waitStatusStart := time.Now() + ctx, cancel := context.WithTimeout(w.baseContext, w.conf.StatusGracePeriodDuration) + defer cancel() + for { + select { + case <-time.After(time.Second): + // pass + + case <-ctx.Done(): + w.logger.Error("error waiting for next status report to controller", "err", ctx.Err()) + return ctx.Err() + } + + if w.lastSuccessfulStatusTime().Sub(waitStatusStart) > 0 { + break + } + } + + w.logger.Debug("next worker status update sent successfully") + return nil +} + func (w *Worker) sendWorkerStatus(cancelCtx context.Context) { // First send info as-is. We'll perform cleanup duties after we // get cancel/job change info back. 
@@ -196,6 +187,7 @@
 	w.lastStatusSuccess.Store(&LastStatusInformation{StatusResponse: result, StatusTime: time.Now()})

 	for _, request := range result.GetJobsRequests() {
+		w.logger.Trace("got job request from controller", "request", request)
 		switch request.GetRequestType() {
 		case pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE:
 			switch request.GetJob().GetType() {
@@ -204,12 +196,25 @@
 				sessionId := sessInfo.GetSessionId()
 				siRaw, ok := w.sessionInfoMap.Load(sessionId)
 				if !ok {
-					w.logger.Warn("asked to cancel session but could not find a local information for it", "session_id", sessionId)
+					w.logger.Warn("session change requested but could not find local information for it", "session_id", sessionId)
 					continue
 				}
 				si := siRaw.(*sessionInfo)
 				si.Lock()
 				si.status = sessInfo.GetStatus()
+				// Update connection state if there are any connections in
+				// the request.
+				for _, conn := range sessInfo.GetConnections() {
+					connId := conn.GetConnectionId()
+					connInfo, ok := si.connInfoMap[conn.GetConnectionId()]
+					if !ok {
+						w.logger.Warn("connection change requested but could not find local information for it", "connection_id", connId)
+						continue
+					}
+
+					connInfo.status = conn.GetStatus()
+				}
+
 				si.Unlock()
 			}
 		}
@@ -240,22 +245,31 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
 			si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING,
 			si.status == pbs.SESSIONSTATUS_SESSIONSTATUS_TERMINATED,
 			time.Until(si.lookupSessionResponse.Expiration.AsTime()) < 0:
-			var toClose int
-			for k, v := range si.connInfoMap {
-				if v.closeTime.IsZero() {
-					toClose++
-					v.connCancel()
-					w.logger.Info("terminated connection due to cancellation or expiration", "session_id", si.id, "connection_id", k)
-					closeInfo[k] = si.id
-				}
+			// Cancel connections without regard to individual connection
+			// state.
+			closedIds := w.cancelConnections(si.connInfoMap, true)
+			for _, connId := range closedIds {
+				closeInfo[connId] = si.id
+				w.logClose(si.id, connId)
 			}
+
 			// closeTime is marked by closeConnections iff the
 			// status is returned for that connection as closed. If
 			// the session is no longer valid and all connections
 			// are marked closed, clean up the session.
-			if toClose == 0 {
+			if len(closedIds) == 0 {
 				cleanSessionIds = append(cleanSessionIds, si.id)
 			}
+
+		default:
+			// Cancel connections *with* regard to individual connection
+			// state (ie: only ones that the controller has requested be
+			// terminated).
+			closedIds := w.cancelConnections(si.connInfoMap, false)
+			for _, connId := range closedIds {
+				closeInfo[connId] = si.id
+				w.logClose(si.id, connId)
+			}
 		}

 		return true
@@ -277,6 +291,33 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
 	}
 }

+// cancelConnections is run by cleanupConnections to iterate over a
+// session's connInfoMap and close connections based on the
+// connection's state (or regardless of state if ignoreConnectionState
+// is set).
+//
+// The returned slice contains the IDs of the connections that were
+// closed.
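+//
+// For example, as used by cleanupConnections above:
+//
+//	// Close everything still open, regardless of reported state:
+//	closedIds := w.cancelConnections(si.connInfoMap, true)
+//	// Close only connections the controller has marked closed:
+//	closedIds = w.cancelConnections(si.connInfoMap, false)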
+func (w *Worker) cancelConnections(connInfoMap map[string]*connInfo, ignoreConnectionState bool) []string { + var closedIds []string + for k, v := range connInfoMap { + if v.closeTime.IsZero() { + if !ignoreConnectionState && v.status != pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED { + continue + } + + v.connCancel() + closedIds = append(closedIds, k) + } + } + + return closedIds +} + +func (w *Worker) logClose(sessionId, connId string) { + w.logger.Info("terminated connection due to cancellation or expiration", "session_id", sessionId, "connection_id", connId) +} + func (w *Worker) lastSuccessfulStatusTime() time.Time { lastStatus := w.LastStatusSuccess() if lastStatus == nil { @@ -288,7 +329,7 @@ func (w *Worker) lastSuccessfulStatusTime() time.Time { func (w *Worker) isPastGrace() (bool, time.Time, time.Duration) { t := w.lastSuccessfulStatusTime() - u := w.statusGracePeriod() + u := w.conf.StatusGracePeriodDuration v := time.Since(t) return v > u, t, u } diff --git a/internal/servers/worker/status_test.go b/internal/servers/worker/status_test.go new file mode 100644 index 0000000000..bf743f8343 --- /dev/null +++ b/internal/servers/worker/status_test.go @@ -0,0 +1,56 @@ +package worker + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/hashicorp/boundary/internal/cmd/base" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +func TestWorkerWaitForNextSuccessfulStatusUpdate(t *testing.T) { + t.Parallel() + require := require.New(t) + for _, name := range []string{"ok", "timeout"} { + t.Run(name, func(t *testing.T) { + // As-needed initialization of a mock worker + w := &Worker{ + logger: hclog.New(nil), + lastStatusSuccess: new(atomic.Value), + baseContext: context.Background(), + conf: &Config{ + Server: &base.Server{ + StatusGracePeriodDuration: time.Second * 1, + }, + }, + } + + // This is present in New() + w.lastStatusSuccess.Store((*LastStatusInformation)(nil)) + + var wg sync.WaitGroup + var err error + wg.Add(1) + go func() { + err = w.WaitForNextSuccessfulStatusUpdate() + wg.Done() + }() + + if name == "ok" { + time.Sleep(time.Millisecond * 100) + w.lastStatusSuccess.Store(&LastStatusInformation{StatusTime: time.Now()}) + } + + wg.Wait() + if name == "timeout" { + require.ErrorIs(err, context.DeadlineExceeded) + } else { + require.NoError(err) + } + }) + } +} diff --git a/internal/servers/worker/testing.go b/internal/servers/worker/testing.go index ddd2bc1952..fedaa0f397 100644 --- a/internal/servers/worker/testing.go +++ b/internal/servers/worker/testing.go @@ -176,6 +176,10 @@ type TestWorkerOpts struct { // The logger to use, or one will be created Logger hclog.Logger + + // The amount of time to wait before marking connections as closed when a + // connection cannot be made back to the controller + StatusGracePeriodDuration time.Duration } func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker { @@ -219,6 +223,9 @@ func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker { }) } + // Initialize status grace period + tw.b.SetStatusGracePeriodDuration(opts.StatusGracePeriodDuration) + if opts.Config.Worker == nil { opts.Config.Worker = new(config.Worker) } @@ -278,10 +285,11 @@ func (tw *TestWorker) AddClusterWorkerMember(t *testing.T, opts *TestWorkerOpts) opts = new(TestWorkerOpts) } nextOpts := &TestWorkerOpts{ - WorkerAuthKms: tw.w.conf.WorkerAuthKms, - Name: opts.Name, - InitialControllers: tw.ControllerAddrs(), - Logger: tw.w.conf.Logger, + WorkerAuthKms: tw.w.conf.WorkerAuthKms, + Name: 
opts.Name, + InitialControllers: tw.ControllerAddrs(), + Logger: tw.w.conf.Logger, + StatusGracePeriodDuration: opts.StatusGracePeriodDuration, } if opts.Logger != nil { nextOpts.Logger = opts.Logger diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index dd62e74d4a..944b243421 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -28,8 +28,8 @@ type Worker struct { started *ua.Bool controllerStatusConn *atomic.Value - workerStartTime time.Time lastStatusSuccess *atomic.Value + workerStartTime time.Time controllerResolver *atomic.Value diff --git a/internal/session/query.go b/internal/session/query.go index c294185c41..563f3a5e69 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -235,13 +235,13 @@ where ) ` - // connectionsToCloseCte finds connections that are: + // closeDeadConnectionsCte finds connections that are: // // * not closed // * not announced by a given server in its latest update // // and marks them as closed. - connectionsToCloseCte = ` + closeDeadConnectionsCte = ` with -- Find connections that are not closed so we can reference those IDs unclosed_connections as ( @@ -257,7 +257,7 @@ with -- It's not in limbo between when it moved into this state and when -- it started being reported by the worker, which is roughly every -- 2-3 seconds - start_time < now() - interval '10 seconds' + start_time < wt_sub_seconds_from_now(10) ), connections_to_close as ( select public_id @@ -278,5 +278,100 @@ with where public_id in (select public_id from connections_to_close) returning public_id + ` + + // closeConnectionsForDeadServersCte finds connections that are: + // + // * not closed + // * belong to servers that have not reported in within an acceptable + // threshold of time + // + // and marks them as closed. + // + // The query returns the set of servers that have had connections closed + // along with their last update time and the number of connections closed on + // each. 
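+	//
+	// An illustrative result row (values invented for this sketch):
+	//
+	//   private_id | update_time            | count
+	//   -----------+------------------------+------
+	//   worker2    | 2021-06-01 00:00:00+00 | 6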
+	closeConnectionsForDeadServersCte = `
+with
+  -- Get dead servers, parameterized off of grace period in seconds
+  dead_servers as (
+    select private_id, update_time
+    from server
+    where update_time < wt_sub_seconds_from_now($1)
+  ),
+  -- Find connections that are not closed so we can reference those IDs
+  unclosed_connections as (
+    select connection_id
+    from session_connection_state
+    where
+      -- It's the current state
+      end_time is null
+      and
+      -- Current state isn't closed state
+      state in ('authorized', 'connected')
+      and
+      -- It's not in limbo between when it moved into this state and when
+      -- it started being reported by the worker, which is roughly every
+      -- 2-3 seconds
+      start_time < wt_sub_seconds_from_now(10)
+  ),
+  connections_to_close as (
+    select public_id
+    from session_connection
+    where
+      -- Related to the worker that just reported to us
+      server_id in (select private_id from dead_servers)
+      and
+      -- Only unclosed ones
+      public_id in (select connection_id from unclosed_connections)
+  ),
+  closed_connections as (
+    update session_connection
+    set
+      closed_reason = 'system error'
+    where
+      public_id in (select public_id from connections_to_close)
+    returning public_id, server_id
+  )
+  select
+    dead_servers.private_id,
+    dead_servers.update_time,
+    count(closed_connections.public_id)
+  from dead_servers
+  left join closed_connections
+    on dead_servers.private_id = closed_connections.server_id
+  group by dead_servers.private_id, dead_servers.update_time
+  having count(closed_connections.public_id) > 0
+  order by dead_servers.private_id
+`
+
+	// shouldCloseConnectionsCte finds connections that are marked as closed in
+	// the database given a set of connection IDs. They are returned along with
+	// their associated session ID.
+	//
+	// The second parameter is a set of session IDs that we have already
+	// submitted a session-wide close request for, so sending another change
+	// request for them would be redundant.
+	shouldCloseConnectionsCte = `
+with
+  -- Find connections that are closed so we can reference those IDs
+  closed_connections as (
+    select connection_id
+    from session_connection_state
+    where
+      -- It's the current state
+      end_time is null
+      and
+      -- Current state is closed
+      state = 'closed'
+      and
+      connection_id in (%s)
+  )
+  select public_id, session_id
+  from session_connection
+  where
+    public_id in (select connection_id from closed_connections)
+    -- Below fmt arg is filled in if there are session IDs to filter against
+    %s
 `
)
diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go
index c145c14baa..80ff88ef3c 100644
--- a/internal/session/repository_connection.go
+++ b/internal/session/repository_connection.go
@@ -4,11 +4,18 @@ import (
 	"context"
 	"fmt"
 	"strings"
+	"time"

 	"github.com/hashicorp/boundary/internal/db"
 	"github.com/hashicorp/boundary/internal/errors"
+	"github.com/hashicorp/boundary/internal/servers"
 )

+// DeadWorkerConnCloseMinGrace is the minimum allowable setting for
+// the CloseConnectionsForDeadWorkers method. This is synced with
+// the default server liveness setting.
+var DeadWorkerConnCloseMinGrace = int(servers.DefaultLiveness.Seconds())
+
 // LookupConnection will look up a connection in the repository and return the connection
 // with its states. If the connection is not found, it will return nil, nil, nil.
 // No options are currently supported.
@@ -99,18 +106,16 @@ func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ ..
 	return rowsDeleted, nil
 }

-// CloseDeadConnectionsOnWorkerReport will run the connectionsToClose CTE to
-// look for connections that should be marked closed because they are no longer
-// claimed by a server. This does notdetect connections where the server is no
-// longer reporting status; that's for a different CTE to be called by a
-// heartbeat detector.
+// CloseDeadConnectionsForWorker will run closeDeadConnectionsCte to look for
+// connections that should be marked closed because they are no longer claimed
+// by a server.
 //
 // The foundConns input should be the currently-claimed connections; the CTE
 // uses a NOT IN clause to ensure these are excluded. It is not an error for
 // this to be empty as the worker could claim no connections; in that case all
 // connections will immediately transition to closed.
-func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, serverId string, foundConns []string) (int, error) {
-	const op = "session.(Repository).CloseDeadConnectionsOnWorkerReport"
+func (r *Repository) CloseDeadConnectionsForWorker(ctx context.Context, serverId string, foundConns []string) (int, error) {
+	const op = "session.(Repository).CloseDeadConnectionsForWorker"
 	if serverId == "" {
 		return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing server id")
 	}
@@ -135,7 +140,7 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser
 		db.ExpBackoff{},
 		func(reader db.Reader, w db.Writer) error {
 			var err error
-			rowsAffected, err = w.Exec(ctx, fmt.Sprintf(connectionsToCloseCte, publicIdStr), args)
+			rowsAffected, err = w.Exec(ctx, fmt.Sprintf(closeDeadConnectionsCte, publicIdStr), args)
 			if err != nil {
 				return errors.Wrap(err, op)
 			}
@@ -148,6 +153,129 @@ func (r *Repository) CloseDeadConnectionsOnWorkerReport(ctx context.Context, ser
 	return rowsAffected, nil
 }

+type CloseConnectionsForDeadWorkersResult struct {
+	ServerId                string
+	LastUpdateTime          time.Time
+	NumberConnectionsClosed int
+}
+
+// CloseConnectionsForDeadWorkers will run
+// closeConnectionsForDeadServersCte to look for connections that
+// should be marked closed because they are on a server that is no
+// longer sending status updates to the controller(s).
+//
+// The only input to the method is the grace period, in seconds.
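+//
+// A minimal usage sketch (the session cleanup job in
+// session_cleanup_job.go is the real caller):
+//
+//	results, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
+//	if err != nil {
+//		return err
+//	}
+//	for _, r := range results {
+//		log.Printf("closed %d connections on %s", r.NumberConnectionsClosed, r.ServerId)
+//	}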
+func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod int) ([]CloseConnectionsForDeadWorkersResult, error) {
+	const op = "session.(Repository).CloseConnectionsForDeadWorkers"
+	if gracePeriod < DeadWorkerConnCloseMinGrace {
+		return nil, errors.New(
+			errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace))
+	}
+
+	results := make([]CloseConnectionsForDeadWorkersResult, 0)
+	_, err := r.writer.DoTx(
+		ctx,
+		db.StdRetryCnt,
+		db.ExpBackoff{},
+		func(reader db.Reader, w db.Writer) error {
+			rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, []interface{}{gracePeriod})
+			if err != nil {
+				return errors.Wrap(err, op)
+			}
+			defer rows.Close()
+
+			for rows.Next() {
+				var (
+					serverId                string
+					lastUpdateTime          time.Time
+					numberConnectionsClosed int
+				)
+
+				if err := rows.Scan(&serverId, &lastUpdateTime, &numberConnectionsClosed); err != nil {
+					return errors.Wrap(err, op)
+				}
+
+				results = append(results, CloseConnectionsForDeadWorkersResult{
+					ServerId:                serverId,
+					LastUpdateTime:          lastUpdateTime,
+					NumberConnectionsClosed: numberConnectionsClosed,
+				})
+			}
+
+			return nil
+		},
+	)
+
+	if err != nil {
+		return nil, errors.Wrap(err, op)
+	}
+
+	return results, nil
+}
+
+// ShouldCloseConnectionsOnWorker will run shouldCloseConnectionsCte to look
+// for connections that the worker should close because it is currently
+// reporting them as open when they are already closed in the database.
+//
+// The foundConns input here is used to filter closed connection states. This
+// is further filtered against the filterSessions input, which is expected to
+// be a set of sessions we've already submitted close requests for, so adding
+// them again would be redundant.
+//
+// The returned map[string][]string is indexed by session ID.
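+//
+// For example (IDs hypothetical):
+//
+//	toClose, err := repo.ShouldCloseConnectionsOnWorker(ctx,
+//		[]string{"sc_1", "sc_2"}, // connections the worker reports as open
+//		[]string{"s_1"},          // sessions already being canceled
+//	)
+//	// toClose maps session ID -> connection IDs the worker should close.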
+func (r *Repository) ShouldCloseConnectionsOnWorker(ctx context.Context, foundConns, filterSessions []string) (map[string][]string, error) { + const op = "session.(Repository).ShouldCloseConnectionsOnWorker" + if len(foundConns) < 1 { + return nil, nil // nothing to do + } + + args := make([]interface{}, 0, len(foundConns)+len(filterSessions)) + + // foundConns first + connsParams := make([]string, len(foundConns)) + for i, connId := range foundConns { + connsParams[i] = fmt.Sprintf("$%d", i+1) + args = append(args, connId) + } + connsStr := strings.Join(connsParams, ",") + + // then filterSessions + var sessionsStr string + if len(filterSessions) > 0 { + offset := len(foundConns) + 1 + sessionsParams := make([]string, len(filterSessions)) + for i, sessionId := range filterSessions { + sessionsParams[i] = fmt.Sprintf("$%d", i+offset) + args = append(args, sessionId) + } + + const sessionIdFmtStr = `and session_id not in (%s)` + sessionsStr = fmt.Sprintf(sessionIdFmtStr, strings.Join(sessionsParams, ",")) + } + + rows, err := r.reader.Query( + ctx, + fmt.Sprintf(shouldCloseConnectionsCte, connsStr, sessionsStr), + args, + ) + if err != nil { + return nil, errors.Wrap(err, op) + } + defer rows.Close() + + result := make(map[string][]string) + for rows.Next() { + var connectionId, sessionId string + if err := rows.Scan(&connectionId, &sessionId); err != nil { + return nil, errors.Wrap(err, op) + } + + result[sessionId] = append(result[sessionId], connectionId) + } + + return result, nil +} + func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) { const op = "session.fetchConnectionStates" var states []*ConnectionState diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go index 18454793a5..b0498d4361 100644 --- a/internal/session/repository_connection_test.go +++ b/internal/session/repository_connection_test.go @@ -2,6 +2,7 @@ package session import ( "context" + "fmt" "testing" "time" @@ -10,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/shared-secure-libs/strutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,7 +212,7 @@ func TestRepository_DeleteConnection(t *testing.T) { } } -func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { +func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { t.Parallel() require, assert := require.New(t), assert.New(t) conn, _ := db.TestSetup(t, "postgres") @@ -284,7 +286,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { // all connection IDs for worker 1 should be showing as non-closed, and // the ones for worker 2 not advertised should be closed. shouldStayOpen := worker2ConnIds[0:2] - count, err := repo.CloseDeadConnectionsOnWorkerReport(ctx, worker2.GetPrivateId(), shouldStayOpen) + count, err := repo.CloseDeadConnectionsForWorker(ctx, worker2.GetPrivateId(), shouldStayOpen) require.NoError(err) assert.Equal(4, count) @@ -310,7 +312,7 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { // Now, advertise none of the connection IDs for worker 2. This is mainly to // test that handling the case where we do not include IDs works properly as // it changes the where clause. 
- count, err = repo.CloseDeadConnectionsOnWorkerReport(ctx, worker1.GetPrivateId(), nil) + count, err = repo.CloseDeadConnectionsForWorker(ctx, worker1.GetPrivateId(), nil) require.NoError(err) assert.Equal(6, count) @@ -330,3 +332,440 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { assert.True(foundClosed != strutil.StrListContains(shouldStayOpen, conn.PublicId)) } } + +func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { + t.Parallel() + require := require.New(t) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(err) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(err) + ctx := context.Background() + + // connection count = 6 * states(authorized, connected, closed = 3) * servers_with_open_connections(3) + numConns := 54 + + // Create four "workers". This is similar to the setup in + // TestRepository_CloseDeadConnectionsOnWorker, but a bit more complex; + // firstly, the last worker will have no connections at all, and we will be + // closing the others in stages to test multiple servers being closed at + // once. + worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1")) + worker2 := TestWorker(t, conn, wrapper, WithServerId("worker2")) + worker3 := TestWorker(t, conn, wrapper, WithServerId("worker3")) + worker4 := TestWorker(t, conn, wrapper, WithServerId("worker4")) + + // Create sessions on the first three, activate, and authorize connections + var worker1ConnIds, worker2ConnIds, worker3ConnIds []string + for i := 0; i < numConns; i++ { + var serverId string + if i%3 == 0 { + serverId = worker1.PrivateId + } else if i%3 == 1 { + serverId = worker2.PrivateId + } else { + serverId = worker3.PrivateId + } + sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) + sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) + require.NoError(err) + c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + require.NoError(err) + require.Len(cs, 1) + require.Equal(StatusAuthorized, cs[0].Status) + if i%3 == 0 { + worker1ConnIds = append(worker1ConnIds, c.GetPublicId()) + } else if i%3 == 1 { + worker2ConnIds = append(worker2ConnIds, c.GetPublicId()) + } else { + worker3ConnIds = append(worker3ConnIds, c.GetPublicId()) + } + } + + // Mark a third of the connections connected, a third closed, and leave the + // others authorized. This is just to ensure we have a spread when we test it + // out. + for i, connId := range func() []string { + var s []string + s = append(s, worker1ConnIds...) + s = append(s, worker2ConnIds...) + s = append(s, worker3ConnIds...) 
+		return s
+	}() {
+		if i%3 == 0 {
+			_, cs, err := repo.ConnectConnection(ctx, ConnectWith{
+				ConnectionId:       connId,
+				ClientTcpAddress:   "127.0.0.1",
+				ClientTcpPort:      22,
+				EndpointTcpAddress: "127.0.0.1",
+				EndpointTcpPort:    22,
+			})
+			require.NoError(err)
+			require.Len(cs, 2)
+			var foundAuthorized, foundConnected bool
+			for _, status := range cs {
+				if status.Status == StatusAuthorized {
+					foundAuthorized = true
+				}
+				if status.Status == StatusConnected {
+					foundConnected = true
+				}
+			}
+			require.True(foundAuthorized)
+			require.True(foundConnected)
+		} else if i%3 == 1 {
+			resp, err := repo.CloseConnections(ctx, []CloseWith{
+				{
+					ConnectionId: connId,
+					ClosedReason: ConnectionCanceled,
+				},
+			})
+			require.NoError(err)
+			require.Len(resp, 1)
+			cs := resp[0].ConnectionStates
+			require.Len(cs, 2)
+			var foundAuthorized, foundClosed bool
+			for _, status := range cs {
+				if status.Status == StatusAuthorized {
+					foundAuthorized = true
+				}
+				if status.Status == StatusClosed {
+					foundClosed = true
+				}
+			}
+			require.True(foundAuthorized)
+			require.True(foundClosed)
+		}
+	}
+
+	// Sleep for 15 seconds to give the connection states time to
+	// transition.
+	time.Sleep(15 * time.Second)
+
+	// updateServer is a helper for updating the update time for our
+	// servers. The server is read back after the upsert so that we can
+	// reference the most up-to-date fields.
+	updateServer := func(t *testing.T, w *servers.Server) *servers.Server {
+		t.Helper()
+		_, rowsUpdated, err := serversRepo.UpsertServer(ctx, w)
+		require.NoError(err)
+		require.Equal(1, rowsUpdated)
+		servers, err := serversRepo.ListServers(ctx, servers.ServerTypeWorker)
+		require.NoError(err)
+		for _, server := range servers {
+			if server.PrivateId == w.PrivateId {
+				return server
+			}
+		}
+
+		require.FailNowf("server not found after updating", "server: %q", w.PrivateId)
+		// Looks weird, but this is needed to build, as we fail via
+		// testify above instead of returning an error.
+		return nil
+	}
+
+	// requireConnectionStatus is a helper that asserts the expected
+	// status of each connection for a worker: either the
+	// connected/closed/authorized split established in setup, or all
+	// closed when expectAllClosed is true.
+	requireConnectionStatus := func(t *testing.T, connIds []string, expectAllClosed bool) {
+		t.Helper()
+
+		var conns []*Connection
+		require.NoError(repo.list(ctx, &conns, "", nil))
+		for i, connId := range connIds {
+			var expected ConnectionStatus
+			switch {
+			case expectAllClosed:
+				expected = StatusClosed
+
+			case i%3 == 0:
+				expected = StatusConnected
+
+			case i%3 == 1:
+				expected = StatusClosed
+
+			case i%3 == 2:
+				expected = StatusAuthorized
+			}
+
+			_, states, err := repo.LookupConnection(ctx, connId)
+			require.NoError(err)
+			require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected)
+		}
+	}
+
+	// timestampPbAsUTC normalizes the zone on protobuf timestamps to
+	// UTC so that they can be compared against the times reported in
+	// CloseConnectionsForDeadWorkersResult.
+	timestampPbAsUTC := func(t *testing.T, tm time.Time) time.Time {
+		t.Helper()
+		utcLoc, err := time.LoadLocation("Etc/UTC")
+		require.NoError(err)
+		return tm.In(utcLoc)
+	}
+
+	// Now try some scenarios.
+	{
+		// First, test the error/validation case.
+		result, err := repo.CloseConnectionsForDeadWorkers(ctx, 0)
+		require.Equal(err, errors.E(
+			errors.WithCode(errors.InvalidParameter),
+			errors.WithOp("session.(Repository).CloseConnectionsForDeadWorkers"),
+			errors.WithMsg(fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace)),
+		))
+		require.Nil(result)
+	}
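An editorial aside, not part of the patch, on the timestampPbAsUTC helper just defined: a protobuf timestamp's AsTime always yields a time.Time in Go's built-in time.UTC location, while the helper converts into the named "Etc/UTC" location. Two such values describe the same instant but carry different location metadata, which (an assumption based on the helper's comment) is presumably why the test normalizes locations before deep-equality comparison. A minimal sketch of the distinction:

```go
package main

import (
	"fmt"
	"time"

	"google.golang.org/protobuf/types/known/timestamppb"
)

func main() {
	// AsTime returns a time.Time in the built-in time.UTC location.
	now := timestamppb.Now().AsTime()

	// "Etc/UTC" is a distinct location value loaded from the tz
	// database, even though it describes the same offset.
	utcLoc, err := time.LoadLocation("Etc/UTC")
	if err != nil {
		panic(err)
	}
	normalized := now.In(utcLoc)

	fmt.Println(now.Equal(normalized))                    // true: same instant
	fmt.Println(now.Location() == normalized.Location()) // false: different location values
}
```

+
+	{
+		// Now, try the base case, where all workers are reporting in.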
+		worker1 = updateServer(t, worker1)
+		worker2 = updateServer(t, worker2)
+		worker3 = updateServer(t, worker3)
+		updateServer(t, worker4) // no re-assignment here because we never reference the server again
+
+		result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
+		require.NoError(err)
+		require.Empty(result)
+		// Expect appropriate split connection state on worker1
+		requireConnectionStatus(t, worker1ConnIds, false)
+		// Expect appropriate split connection state on worker2
+		requireConnectionStatus(t, worker2ConnIds, false)
+		// Expect appropriate split connection state on worker3
+		requireConnectionStatus(t, worker3ConnIds, false)
+	}
+
+	{
+		// Now try a zero case - similar to the base case, except that
+		// worker4 (which has no connections) is allowed to go dead. No
+		// results should be returned for workers with no connections,
+		// even if they are dead.
+		time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
+		worker1 = updateServer(t, worker1)
+		worker2 = updateServer(t, worker2)
+		worker3 = updateServer(t, worker3)
+
+		result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
+		require.NoError(err)
+		require.Empty(result)
+		// Expect appropriate split connection state on worker1
+		requireConnectionStatus(t, worker1ConnIds, false)
+		// Expect appropriate split connection state on worker2
+		requireConnectionStatus(t, worker2ConnIds, false)
+		// Expect appropriate split connection state on worker3
+		requireConnectionStatus(t, worker3ConnIds, false)
+	}
+
+	{
+		// The first induction is letting the first worker "die" by not
+		// updating it while the others are updated. All of its
+		// authorized and connected connections should then be closed.
+		time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
+		worker2 = updateServer(t, worker2)
+		worker3 = updateServer(t, worker3)
+
+		result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
+		require.NoError(err)
+		// Assert that we have one result with the appropriate ID and
+		// number of connections closed.
+		require.Equal([]CloseConnectionsForDeadWorkersResult{
+			{
+				ServerId:                worker1.PrivateId,
+				LastUpdateTime:          timestampPbAsUTC(t, worker1.UpdateTime.AsTime()),
+				NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
+			},
+		}, result)
+		// Expect all connections closed on worker1
+		requireConnectionStatus(t, worker1ConnIds, true)
+		// Expect appropriate split connection state on worker2
+		requireConnectionStatus(t, worker2ConnIds, false)
+		// Expect appropriate split connection state on worker3
+		requireConnectionStatus(t, worker3ConnIds, false)
+	}
+
+	{
+		// The final case is having the other two workers die. After
+		// this, all connections should be closed, with results
+		// reported for the two remaining servers.
+		time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace))
+
+		result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace)
+		require.NoError(err)
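An editorial aside, not part of the patch: the expected NumberConnectionsClosed of 12 falls out of the setup arithmetic. The test creates 54 connections spread evenly across the three connection-bearing workers (18 each), and the setup loop splits each worker's connections evenly across the connected, closed, and authorized states (6 each); the sweep only closes connections that are not already closed. A quick check of that arithmetic:

```go
package main

import "fmt"

func main() {
	const (
		numConns         = 54 // total connections created by the test
		workersWithConns = 3  // worker1 through worker3
		numStates        = 3  // connected, closed, authorized
	)

	perWorker := numConns / workersWithConns // 18
	perState := perWorker / numStates        // 6

	// Already-closed connections are not re-closed by the sweep, so
	// only the connected + authorized ones count toward the result.
	expectClosed := perWorker - perState

	fmt.Println(perWorker, perState, expectClosed) // 18 6 12
}
```

+		// Assert that we have two results with the appropriate IDs and
+		// numbers of connections closed.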
+		require.Equal([]CloseConnectionsForDeadWorkersResult{
+			{
+				ServerId:                worker2.PrivateId,
+				LastUpdateTime:          timestampPbAsUTC(t, worker2.UpdateTime.AsTime()),
+				NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
+			},
+			{
+				ServerId:                worker3.PrivateId,
+				LastUpdateTime:          timestampPbAsUTC(t, worker3.UpdateTime.AsTime()),
+				NumberConnectionsClosed: 12, // 18 per server, with 6 closed already
+			},
+		}, result)
+		// Expect all connections closed on worker1
+		requireConnectionStatus(t, worker1ConnIds, true)
+		// Expect all connections closed on worker2
+		requireConnectionStatus(t, worker2ConnIds, true)
+		// Expect all connections closed on worker3
+		requireConnectionStatus(t, worker3ConnIds, true)
+	}
+}
+
+func TestRepository_ShouldCloseConnectionsOnWorker(t *testing.T) {
+	t.Parallel()
+	require := require.New(t)
+	conn, _ := db.TestSetup(t, "postgres")
+	rw := db.New(conn)
+	wrapper := db.TestWrapper(t)
+	iamRepo := iam.TestRepo(t, conn, wrapper)
+	kms := kms.TestKms(t, conn, wrapper)
+	repo, err := NewRepository(rw, rw, kms)
+	require.NoError(err)
+	ctx := context.Background()
+	numConns := 12
+
+	// Create a worker. We only need one here, as the query depends on
+	// connections, not workers.
+	worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1"))
+
+	// Create a few sessions, activate them, and authorize a connection
+	// on each.
+	var connIds []string
+	sessionConnIds := make(map[string][]string)
+	for i := 0; i < numConns; i++ {
+		serverId := worker1.PrivateId
+		sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true)))
+		sessionId := sess.GetPublicId()
+		sess, _, err = repo.ActivateSession(ctx, sessionId, sess.Version, serverId, "worker", []byte("foo"))
+		require.NoError(err)
+		c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
+		require.NoError(err)
+		require.Len(cs, 1)
+		require.Equal(StatusAuthorized, cs[0].Status)
+		connId := c.GetPublicId()
+		connIds = append(connIds, connId)
+		sessionConnIds[sessionId] = append(sessionConnIds[sessionId], connId)
+	}
+
+	// Mark half of the connections connected, close the other half.
+	for i, connId := range connIds {
+		if i%2 == 0 {
+			_, cs, err := repo.ConnectConnection(ctx, ConnectWith{
+				ConnectionId:       connId,
+				ClientTcpAddress:   "127.0.0.1",
+				ClientTcpPort:      22,
+				EndpointTcpAddress: "127.0.0.1",
+				EndpointTcpPort:    22,
+			})
+			require.NoError(err)
+			require.Len(cs, 2)
+			var foundAuthorized, foundConnected bool
+			for _, status := range cs {
+				if status.Status == StatusAuthorized {
+					foundAuthorized = true
+				}
+				if status.Status == StatusConnected {
+					foundConnected = true
+				}
+			}
+			require.True(foundAuthorized)
+			require.True(foundConnected)
+		} else {
+			resp, err := repo.CloseConnections(ctx, []CloseWith{
+				{
+					ConnectionId: connId,
+					ClosedReason: ConnectionCanceled,
+				},
+			})
+			require.NoError(err)
+			require.Len(resp, 1)
+			cs := resp[0].ConnectionStates
+			require.Len(cs, 2)
+			var foundAuthorized, foundClosed bool
+			for _, status := range cs {
+				if status.Status == StatusAuthorized {
+					foundAuthorized = true
+				}
+				if status.Status == StatusClosed {
+					foundClosed = true
+				}
+			}
+			require.True(foundAuthorized)
+			require.True(foundClosed)
+		}
+	}
+
+	// Sleep for 15 seconds to give the connection states time to
+	// transition.
+	time.Sleep(15 * time.Second)
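An editorial aside, not part of the patch: a repository method like CloseConnectionsForDeadWorkers is presumably meant to be driven on an interval by the controller. A hypothetical sketch of such a driver; the function, its wiring, and the half-grace-period tick are illustrative assumptions, not the controller's actual scheduling code:

```go
package sweep

import (
	"context"
	"time"

	"github.com/hashicorp/boundary/internal/session"
	"github.com/hashicorp/go-hclog"
)

// RunDeadWorkerSweep periodically closes connections belonging to
// workers that have not reported in within the grace period.
// Hypothetical sketch only.
func RunDeadWorkerSweep(ctx context.Context, repo *session.Repository, gracePeriod time.Duration, logger hclog.Logger) {
	ticker := time.NewTicker(gracePeriod / 2) // sweep more often than the grace period
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			results, err := repo.CloseConnectionsForDeadWorkers(ctx, int(gracePeriod.Seconds()))
			if err != nil {
				logger.Error("dead worker sweep failed", "err", err)
				continue
			}
			for _, r := range results {
				logger.Info("closed connections for dead worker",
					"server_id", r.ServerId,
					"last_update_time", r.LastUpdateTime,
					"number_connections_closed", r.NumberConnectionsClosed,
				)
			}
		}
	}
}
```

+
+	// Now we try some scenarios.
+	{
+		// First test an empty set.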
+		result, err := repo.ShouldCloseConnectionsOnWorker(ctx, nil, nil)
+		require.NoError(err)
+		require.Zero(result, "should be empty when no connections are supplied")
+	}
+
+	{
+		// Here we pass in all of our connections without a filter on
+		// session. This should return half of the connections - the ones
+		// that we marked as closed.
+		//
+		// Create a copy of our session map containing only the closed
+		// connections.
+		expectedSessionConnIds := make(map[string][]string)
+		for sessionId, connIds := range sessionConnIds {
+			for _, connId := range connIds {
+				if testIsConnectionClosed(ctx, t, repo, connId) {
+					expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId)
+				}
+			}
+		}
+
+		// Send query, use all connections w/o a filter on sessions.
+		actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, nil)
+		require.NoError(err)
+		require.Equal(expectedSessionConnIds, actualSessionConnIds)
+	}
+
+	{
+		// Finally, add a session filter. We do this by alternating the
+		// session IDs we want to filter on.
+		expectedSessionConnIds := make(map[string][]string)
+		var filterSessionIds []string
+		var filterSession bool
+		for sessionId, connIds := range sessionConnIds {
+			for _, connId := range connIds {
+				if testIsConnectionClosed(ctx, t, repo, connId) {
+					if !filterSession {
+						expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId)
+					} else {
+						filterSessionIds = append(filterSessionIds, sessionId)
+					}
+
+					// Toggle filterSession here (instead of in the outer
+					// session loop) so that we aren't just lining up on
+					// connected/disconnected connections.
+					filterSession = !filterSession
+				}
+			}
+		}
+
+		// Send query with the session filter.
+		actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, filterSessionIds)
+		require.NoError(err)
+		require.Equal(expectedSessionConnIds, actualSessionConnIds)
+	}
+}
+
+func testIsConnectionClosed(ctx context.Context, t *testing.T, repo *Repository, connId string) bool {
+	require := require.New(t)
+	_, states, err := repo.LookupConnection(ctx, connId)
+	require.NoError(err)
+	// Use the first state, as LookupConnection returns states ordered
+	// by start time, descending.
+	return states[0].Status == StatusClosed
+}
diff --git a/internal/tests/cluster/session_cleanup_test.go b/internal/tests/cluster/session_cleanup_test.go
index ab506a2df1..54140d93a1 100644
--- a/internal/tests/cluster/session_cleanup_test.go
+++ b/internal/tests/cluster/session_cleanup_test.go
@@ -37,9 +37,9 @@ import (
 )
 
 const (
-	testSendRecvSendMax                  = 60
-	defaultGracePeriod                   = time.Second * 30
-	expectConnectionStateOnWorkerTimeout = defaultGracePeriod * 2
+	defaultGracePeriod                       = time.Second * 15
+	expectConnectionStateOnControllerTimeout = time.Minute * 2
+	expectConnectionStateOnWorkerTimeout     = defaultGracePeriod * 2
 
 	// This is the interval that we check states on in the worker. It
 	// needs to be particularly granular to ensure that we allow for
@@ -55,227 +55,414 @@ const (
 	expectConnectionStateOnWorkerInterval = time.Millisecond * 100
 )
 
-func TestWorkerSessionCleanup(t *testing.T) {
-	require := require.New(t)
-	logger := hclog.New(&hclog.LoggerOptions{
-		Level: hclog.Trace,
-	})
-
-	conf, err := config.DevController()
-	require.NoError(err)
-
-	pl, err := net.Listen("tcp", "localhost:0")
-	require.NoError(err)
-	c1 := controller.NewTestController(t, &controller.TestControllerOpts{
-		Config:                 conf,
-		InitialResourcesSuffix: "1234567890",
-		Logger:                 logger.Named("c1"),
-		PublicClusterAddr:      pl.Addr().String(),
-	})
-	defer c1.Shutdown()
-
-	expectWorkers(t, c1)
-
-	// Wire up the testing proxies
-	require.Len(c1.ClusterAddrs(), 1)
-	proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0],
-		dawdle.WithListener(pl),
-		dawdle.WithRbufSize(512),
-		dawdle.WithWbufSize(512),
-	)
-	require.NoError(err)
-	defer proxy.Close()
-	require.NotEmpty(t, proxy.ListenerAddr())
-
-	w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
-		WorkerAuthKms:      c1.Config().WorkerAuthKms,
-		InitialControllers: []string{proxy.ListenerAddr()},
-		Logger:             logger.Named("w1"),
-	})
-	defer w1.Shutdown()
+// timeoutBurdenType details our "burden cases" for the session
+// cleanup tests.
+//
+// There are 3 burden cases:
+//
+// * default: This case simulates normal default operation, where the
+// worker and the controller time out connections at roughly the same
+// interval. In reality the two settings will not necessarily match,
+// but it's hard to test the individual cases when both settings are
+// the same.
+//
+// * worker: This case treats the worker as the source of truth for
+// connection state. The controller's grace period is increased to a
+// high factor over the default to ensure that the worker manages the
+// lifecycle of a connection: the worker closes the connection itself,
+// and once status reporting resumes, the controller marks the
+// connection as closed as well.
+//
+// * controller: Here, the controller is the one doing the work. The
+// connection stays open on the worker until status checks resume from
+// the worker. At that point, the controller requests the status
+// change in its reply to the worker, which physically closes the
+// connection there.
+type timeoutBurdenType string
 
-	time.Sleep(10 * time.Second)
-	expectWorkers(t, c1, w1)
+const (
+	timeoutBurdenTypeDefault    timeoutBurdenType = "default"
+	timeoutBurdenTypeWorker     timeoutBurdenType = "worker"
+	timeoutBurdenTypeController timeoutBurdenType = "controller"
+)
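An editorial aside, not part of the patch: to make the burden cases concrete, here is a small self-contained program mirroring the grace-period matrix that the controllerGracePeriod/workerGracePeriod helpers below set up. The 10x factor and the 15-second default come from this diff; the program itself is only illustrative:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const defaultGracePeriod = 15 * time.Second

	// The side NOT bearing the timeout burden gets a 10x grace period,
	// ensuring the other side is always the one that acts first.
	for _, burden := range []string{"default", "worker", "controller"} {
		controllerGP, workerGP := defaultGracePeriod, defaultGracePeriod
		switch burden {
		case "worker":
			controllerGP *= 10 // worker is the source of truth; controller waits
		case "controller":
			workerGP *= 10 // controller is the source of truth; worker waits
		}
		fmt.Printf("%-10s controller=%v worker=%v\n", burden, controllerGP, workerGP)
	}
}
```

-
-	// Use an independent context for test things that take a context so
-	// that we aren't tied to any timeouts in the controller, etc. This
-	// can interfere with some of the test operations.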
- ctx := context.Background() +var timeoutBurdenCases = []timeoutBurdenType{timeoutBurdenTypeDefault, timeoutBurdenTypeWorker, timeoutBurdenTypeController} - // Connect target - client := c1.Client() - client.SetToken(c1.Token().Token) - tcl := targets.NewClient(client) - tgt, err := tcl.Read(ctx, "ttcp_1234567890") - require.NoError(err) - require.NotNil(tgt) +func controllerGracePeriod(ty timeoutBurdenType) time.Duration { + if ty == timeoutBurdenTypeWorker { + return defaultGracePeriod * 10 + } - // Create test server, update default port on target - ts := newTestTcpServer(t, logger) - require.NotNil(t, ts) - defer ts.Close() - tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port())) - require.NoError(err) - require.NotNil(tgt) + return defaultGracePeriod +} - // Authorize and connect - sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") - sConn := sess.Connect(ctx, t, logger) +func workerGracePeriod(ty timeoutBurdenType) time.Duration { + if ty == timeoutBurdenTypeController { + return defaultGracePeriod * 10 + } - // Run initial send/receive test, make sure things are working - sConn.TestSendRecvAll(t) + return defaultGracePeriod +} - // Kill the link - proxy.Pause() +// TestWorkerSessionCleanup is the main test for session cleanup, and +// dispatches to the individual subtests. +func TestWorkerSessionCleanup(t *testing.T) { + t.Parallel() + for _, burdenCase := range timeoutBurdenCases { + burdenCase := burdenCase + t.Run(string(burdenCase), func(t *testing.T) { + t.Parallel() + t.Run("single_controller", testWorkerSessionCleanupSingle(burdenCase)) + t.Run("multi_controller", testWorkerSessionCleanupMulti(burdenCase)) + }) + } +} - // Run again, ensure connection is dead - sConn.TestSendRecvFail(t) +func testWorkerSessionCleanupSingle(burdenCase timeoutBurdenType) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + require := require.New(t) + logger := hclog.New(&hclog.LoggerOptions{ + Name: t.Name(), + Level: hclog.Trace, + }) + + conf, err := config.DevController() + require.NoError(err) + + pl, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + InitialResourcesSuffix: "1234567890", + Logger: logger.Named("c1"), + PublicClusterAddr: pl.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c1.Shutdown() + + expectWorkers(t, c1) + + // Wire up the testing proxies + require.Len(c1.ClusterAddrs(), 1) + proxy, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], + dawdle.WithListener(pl), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer proxy.Close() + require.NotEmpty(t, proxy.ListenerAddr()) + + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: []string{proxy.ListenerAddr()}, + Logger: logger.Named("w1"), + StatusGracePeriodDuration: workerGracePeriod(burdenCase), + }) + defer w1.Shutdown() + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + err = c1.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + expectWorkers(t, c1, w1) + + // Use an independent context for test things that take a context so + // that we aren't tied to any timeouts in the controller, etc. This + // can interfere with some of the test operations. 
+ ctx := context.Background() + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + require.NoError(err) + require.NotNil(tgt) + + // Create test server, update default port on target + ts := newTestTcpServer(t, logger) + require.NotNil(t, ts) + defer ts.Close() + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + require.NoError(err) + require.NotNil(tgt) + + // Authorize and connect + sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") + sConn := sess.Connect(ctx, t) + + // Run initial send/receive test, make sure things are working + logger.Debug("running initial send/recv test") + sConn.TestSendRecvAll(t) + + // Kill the link + logger.Debug("pausing controller/worker link") + proxy.Pause() + + // Wait for failure connection state (depends on burden case) + switch burdenCase { + case timeoutBurdenTypeWorker: + // Wait on worker, then check controller + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected) + + case timeoutBurdenTypeController: + // Wait on controller, then check worker + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + + default: + // Should be closed on both worker and controller. Wait on + // worker then check controller. + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + } - // Assert we have no connections left (should be default behavior) - sess.TestNoConnectionsLeft(t) + // Run send/receive test again to check expected connection-level + // behavior + if burdenCase == timeoutBurdenTypeController { + // Burden on controller, should be successful until connection + // resumes + sConn.TestSendRecvAll(t) + } else { + // Connection should die in other cases + sConn.TestSendRecvFail(t) + } - // Assert connection has been removed from the local worker state - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + // Resume the connection, and reconnect. + logger.Debug("resuming controller/worker link") + proxy.Resume() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + + // Do something post-reconnect depending on burden case. Note in + // the default case, both worker and controller should be + // relatively in sync, so we don't worry about these + // post-reconnection assertions. + switch burdenCase { + case timeoutBurdenTypeWorker: + // If we are expecting the worker to be the source of truth of + // a connection status, ensure that our old session's + // connections are actually closed now that the worker is + // properly reporting in again. + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + + case timeoutBurdenTypeController: + // If we are expecting the controller to be the source of + // truth, the connection should now be forcibly closed after + // the worker gets a status change request back. + sConn.TestSendRecvFail(t) + } - // Resume the connection, and reconnect. 
- proxy.Resume() - time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy - sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() - sConn = sess.Connect(ctx, t, logger) - sConn.TestSendRecvAll(t) + // Proceed with new connection test + logger.Debug("connecting to new session after resuming controller/worker link") + sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() + sConn = sess.Connect(ctx, t) + sConn.TestSendRecvAll(t) + } } -func TestWorkerSessionCleanupMultiController(t *testing.T) { - require := require.New(t) - logger := hclog.New(&hclog.LoggerOptions{ - Level: hclog.Trace, - }) - - // ****************** - // ** Controller 1 ** - // ****************** - conf1, err := config.DevController() - require.NoError(err) +func testWorkerSessionCleanupMulti(burdenCase timeoutBurdenType) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + require := require.New(t) + logger := hclog.New(&hclog.LoggerOptions{ + Name: t.Name(), + Level: hclog.Trace, + }) + + // ****************** + // ** Controller 1 ** + // ****************** + conf1, err := config.DevController() + require.NoError(err) + + pl1, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf1, + InitialResourcesSuffix: "1234567890", + Logger: logger.Named("c1"), + PublicClusterAddr: pl1.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c1.Shutdown() + + // ****************** + // ** Controller 2 ** + // ****************** + pl2, err := net.Listen("tcp", "localhost:0") + require.NoError(err) + c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{ + Logger: logger.Named("c2"), + PublicClusterAddr: pl2.Addr().String(), + StatusGracePeriodDuration: controllerGracePeriod(burdenCase), + }) + defer c2.Shutdown() + expectWorkers(t, c1) + expectWorkers(t, c2) + + // ************* + // ** Proxy 1 ** + // ************* + require.Len(c1.ClusterAddrs(), 1) + p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], + dawdle.WithListener(pl1), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer p1.Close() + require.NotEmpty(t, p1.ListenerAddr()) + + // ************* + // ** Proxy 2 ** + // ************* + require.Len(c2.ClusterAddrs(), 1) + p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0], + dawdle.WithListener(pl2), + dawdle.WithRbufSize(256), + dawdle.WithWbufSize(256), + ) + require.NoError(err) + defer p2.Close() + require.NotEmpty(t, p2.ListenerAddr()) + + // ************ + // ** Worker ** + // ************ + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()}, + Logger: logger.Named("w1"), + StatusGracePeriodDuration: workerGracePeriod(burdenCase), + }) + defer w1.Shutdown() + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + err = c1.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + err = c2.WaitForNextWorkerStatusUpdate(w1.Name()) + require.NoError(err) + expectWorkers(t, c1, w1) + expectWorkers(t, c2, w1) + + // Use an independent context for test things that take a context so + // that we aren't tied to any timeouts in the controller, etc. This + // can interfere with some of the test operations. 
+ ctx := context.Background() + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + require.NoError(err) + require.NotNil(tgt) + + // Create test server, update default port on target + ts := newTestTcpServer(t, logger) + require.NotNil(ts) + defer ts.Close() + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + require.NoError(err) + require.NotNil(tgt) + + // Authorize and connect + sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") + sConn := sess.Connect(ctx, t) + + // Run initial send/receive test, make sure things are working + logger.Debug("running initial send/recv test") + sConn.TestSendRecvAll(t) + + // Kill connection to first controller, and run test again, should + // pass, deferring to other controller. Wait for the next + // successful status report to ensure this. + logger.Debug("pausing link to controller #1") + p1.Pause() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + sConn.TestSendRecvAll(t) + + // Resume first controller, pause second. This one should work too. + logger.Debug("pausing link to controller #2, resuming #1") + p1.Resume() + p2.Pause() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + sConn.TestSendRecvAll(t) + + // Kill the first controller connection again. This one should fail + // due to lack of any connection. + logger.Debug("pausing link to controller #1 again, both connections should be offline") + p1.Pause() + + // Wait for failure connection state (depends on burden case) + switch burdenCase { + case timeoutBurdenTypeWorker: + // Wait on worker, then check controller + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusConnected) + + case timeoutBurdenTypeController: + // Wait on controller, then check worker + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + + default: + // Should be closed on both worker and controller. Wait on + // worker then check controller. 
+ sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + } - pl1, err := net.Listen("tcp", "localhost:0") - require.NoError(err) - c1 := controller.NewTestController(t, &controller.TestControllerOpts{ - Config: conf1, - InitialResourcesSuffix: "1234567890", - Logger: logger.Named("c1"), - PublicClusterAddr: pl1.Addr().String(), - }) - defer c1.Shutdown() + // Run send/receive test again to check expected connection-level + // behavior + if burdenCase == timeoutBurdenTypeController { + // Burden on controller, should be successful until connection + // resumes + sConn.TestSendRecvAll(t) + } else { + // Connection should die in other cases + sConn.TestSendRecvFail(t) + } - // ****************** - // ** Controller 2 ** - // ****************** - pl2, err := net.Listen("tcp", "localhost:0") - require.NoError(err) - c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{ - Logger: c1.Config().Logger.ResetNamed("c2"), - PublicClusterAddr: pl2.Addr().String(), - }) - defer c2.Shutdown() - expectWorkers(t, c1) - expectWorkers(t, c2) - - // ************* - // ** Proxy 1 ** - // ************* - require.Len(c1.ClusterAddrs(), 1) - p1, err := dawdle.NewProxy("tcp", "", c1.ClusterAddrs()[0], - dawdle.WithListener(pl1), - dawdle.WithRbufSize(512), - dawdle.WithWbufSize(512), - ) - require.NoError(err) - defer p1.Close() - require.NotEmpty(t, p1.ListenerAddr()) - - // ************* - // ** Proxy 2 ** - // ************* - require.Len(c2.ClusterAddrs(), 1) - p2, err := dawdle.NewProxy("tcp", "", c2.ClusterAddrs()[0], - dawdle.WithListener(pl2), - dawdle.WithRbufSize(512), - dawdle.WithWbufSize(512), - ) - require.NoError(err) - defer p2.Close() - require.NotEmpty(t, p2.ListenerAddr()) - - // ************ - // ** Worker ** - // ************ - w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ - WorkerAuthKms: c1.Config().WorkerAuthKms, - InitialControllers: []string{p1.ListenerAddr(), p2.ListenerAddr()}, - Logger: logger.Named("w1"), - }) - defer w1.Shutdown() - - time.Sleep(10 * time.Second) - expectWorkers(t, c1, w1) - expectWorkers(t, c2, w1) - - // Use an independent context for test things that take a context so - // that we aren't tied to any timeouts in the controller, etc. This - // can interfere with some of the test operations. - ctx := context.Background() - - // Connect target - client := c1.Client() - client.SetToken(c1.Token().Token) - tcl := targets.NewClient(client) - tgt, err := tcl.Read(ctx, "ttcp_1234567890") - require.NoError(err) - require.NotNil(tgt) + // Finally resume both, try again. Should behave as per normal. + logger.Debug("resuming connections to both controllers") + p1.Resume() + p2.Resume() + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + require.NoError(err) + + // Do something post-reconnect depending on burden case. Note in + // the default case, both worker and controller should be + // relatively in sync, so we don't worry about these + // post-reconnection assertions. + switch burdenCase { + case timeoutBurdenTypeWorker: + // If we are expecting the worker to be the source of truth of + // a connection status, ensure that our old session's + // connections are actually closed now that the worker is + // properly reporting in again. 
+ sess.ExpectConnectionStateOnController(ctx, t, c1, session.StatusClosed) + + case timeoutBurdenTypeController: + // If we are expecting the controller to be the source of + // truth, the connection should now be forcibly closed after + // the worker gets a status change request back. + sConn.TestSendRecvFail(t) + } - // Create test server, update default port on target - ts := newTestTcpServer(t, logger) - require.NotNil(ts) - defer ts.Close() - tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port())) - require.NoError(err) - require.NotNil(tgt) - - // Authorize and connect - sess := newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") - sConn := sess.Connect(ctx, t, logger) - - // Run initial send/receive test, make sure things are working - sConn.TestSendRecvAll(t) - - // Kill connection to first controller, and run test again, should - // pass, deferring to other controller. - p1.Pause() - sConn.TestSendRecvAll(t) - - // Resume first controller, pause second. This one should work too. - p1.Resume() - p2.Pause() - sConn.TestSendRecvAll(t) - - // Kill the first controller connection again. This one should fail - // due to lack of any connection. - p1.Pause() - sConn.TestSendRecvFail(t) - - // Assert we have no connections left (should be default behavior) - sess.TestNoConnectionsLeft(t) - - // Assert connection has been removed from the local worker state - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - - // Finally resume both, try again. Should behave as per normal. - p1.Resume() - p2.Resume() - time.Sleep(time.Second * 10) // Sleep to wait for worker to report back as healthy - sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() - sConn = sess.Connect(ctx, t, logger) - sConn.TestSendRecvAll(t) + // Proceed with new connection test + logger.Debug("connecting to new session after resuming controller/worker link") + sess = newTestSession(ctx, t, logger, tcl, "ttcp_1234567890") // re-assign, other connection will close in t.Cleanup() + sConn = sess.Connect(ctx, t) + sConn.TestSendRecvAll(t) + } } // testSession represents an authorized session. @@ -392,10 +579,73 @@ func (s *testSession) connect(ctx context.Context, t *testing.T) net.Conn { return websocket.NetConn(ctx, conn, websocket.MessageBinary) } -// TestNoConnectionsLeft asserts that there are no connections left. -func (s *testSession) TestNoConnectionsLeft(t *testing.T) { +// ExpectConnectionStateOnController waits until all connections in a +// session have transitioned to a particular state on the controller. +func (s *testSession) ExpectConnectionStateOnController( + ctx context.Context, + t *testing.T, + tc *controller.TestController, + expectState session.ConnectionStatus, +) { t.Helper() - require.Zero(t, s.connectionsLeft) + require := require.New(t) + assert := assert.New(t) + + ctx, cancel := context.WithTimeout(ctx, expectConnectionStateOnControllerTimeout) + defer cancel() + + // This is just for initialization of the actual state set. + const sessionStatusUnknown session.ConnectionStatus = "unknown" + + // Get all connections for the session on the controller. + sessionRepo, err := tc.Controller().SessionRepoFn() + require.NoError(err) + + conns, err := sessionRepo.ListConnectionsBySessionId(ctx, s.sessionId) + require.NoError(err) + // To avoid misleading passing tests, we require this test be used + // with sessions with connections.. 
+		require.Greater(len(conns), 0, "should have at least one connection")
+
+		// Make a set of states, 1 per connection
+		actualStates := make([]session.ConnectionStatus, len(conns))
+		for i := range actualStates {
+			actualStates[i] = sessionStatusUnknown
+		}
+
+		// Make the expected set for comparison
+		expectStates := make([]session.ConnectionStatus, len(conns))
+		for i := range expectStates {
+			expectStates[i] = expectState
+		}
+
+		for {
+			if ctx.Err() != nil {
+				break
+			}
+
+			for i, conn := range conns {
+				_, states, err := sessionRepo.LookupConnection(ctx, conn.PublicId, nil)
+				require.NoError(err)
+				// Look at the first state in the returned list, which will
+				// be the most recent state.
+				actualStates[i] = states[0].Status
+			}
+
+			if reflect.DeepEqual(expectStates, actualStates) {
+				break
+			}
+
+			time.Sleep(time.Second)
+		}
+
+		// A "non-fatal" assert here lets us surface both a timeout and
+		// the offending state.
+		assert.NoError(ctx.Err())
+
+		// Assert
+		require.Equal(expectStates, actualStates)
+		s.logger.Debug("successfully asserted all connection states on controller", "expected_states", expectStates, "actual_states", actualStates)
 }
 
 // ExpectConnectionStateOnWorker waits until all connections in a
@@ -491,7 +741,6 @@ type testSessionConnection struct {
 func (s *testSession) Connect(
 	ctx context.Context,
 	t *testing.T, // Just to add cleanup
-	logger hclog.Logger,
 ) *testSessionConnection {
 	t.Helper()
 	require := require.New(t)
@@ -504,7 +753,7 @@ func (s *testSession) Connect(
 
 	return &testSessionConnection{
 		conn:   conn,
-		logger: logger,
+		logger: s.logger,
 	}
 }
@@ -518,6 +767,12 @@ func (s *testSession) Connect(
 func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
 	t.Helper()
 	require := require.New(t)
+
+	// This is a fairly arbitrary value, as the send/recv itself is
+	// instantaneous. The main point is just to run enough rounds to be
+	// reasonably sure the connection is stable.
+	const testSendRecvSendMax = 100
 	for i := uint32(0); i < testSendRecvSendMax; i++ {
 		// Shuttle over the sequence number as base64.
 		err := binary.Write(c.conn, binary.LittleEndian, i)
@@ -536,7 +791,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
 		var j uint32
 		err = binary.Read(c.conn, binary.LittleEndian, &j)
 		if err != nil {
-			c.logger.Debug("received error during read", "err", err)
+			c.logger.Debug("received error during read", "err", err, "num_successfully_sent", i)
 			if errors.Is(err, net.ErrClosed) ||
 				errors.Is(err, io.EOF) ||
 				errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) {
@@ -549,9 +804,8 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
 		require.Equal(j, i)
 
-		// Sleep 1s
-		time.Sleep(time.Second)
 	}
 
+	c.logger.Debug("finished send/recv successfully", "num_successfully_sent", testSendRecvSendMax)
 	return true
 }
@@ -560,6 +816,7 @@ func (c *testSessionConnection) testSendRecv(t *testing.T) bool {
 func (c *testSessionConnection) TestSendRecvAll(t *testing.T) {
 	t.Helper()
 	require.True(t, c.testSendRecv(t))
+	c.logger.Debug("successfully asserted send/recv as passing")
 }
 
 // TestSendRecvFail asserts that we were able to send/recv all pings
@@ -567,6 +824,7 @@ func (c *testSessionConnection) TestSendRecvAll(t *testing.T) {
 func (c *testSessionConnection) TestSendRecvFail(t *testing.T) {
 	t.Helper()
 	require.False(t, c.testSendRecv(t))
+	c.logger.Debug("successfully asserted send/recv as failing")
 }
 
 type testTcpServer struct {
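A final editorial aside, not part of the patch: testSendRecv above drives a simple framed ping protocol - write a little-endian uint32 sequence number, read the echo, and require it to match. A condensed standalone version against a plain TCP echo server (illustrative only; the real tests run this through a Boundary session):

```go
package main

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

// echo accepts one connection and echoes 4-byte frames back, similar
// in spirit to the testTcpServer used by these tests.
func echo(l net.Listener) {
	c, err := l.Accept()
	if err != nil {
		return
	}
	defer c.Close()
	buf := make([]byte, 4)
	for {
		if _, err := io.ReadFull(c, buf); err != nil {
			return
		}
		if _, err := c.Write(buf); err != nil {
			return
		}
	}
}

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer l.Close()
	go echo(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Same shape as testSendRecv: shuttle sequence numbers over the
	// wire as little-endian uint32s and check each echo.
	for i := uint32(0); i < 10; i++ {
		if err := binary.Write(conn, binary.LittleEndian, i); err != nil {
			panic(err)
		}
		var j uint32
		if err := binary.Read(conn, binary.LittleEndian, &j); err != nil {
			panic(err)
		}
		if i != j {
			panic(fmt.Sprintf("expected echo %d, got %d", i, j))
		}
	}
	fmt.Println("all pings echoed")
}
```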