From bc409bc7d4bc89072f2d48b2f7c1c646ae10d02c Mon Sep 17 00:00:00 2001 From: Todd Date: Tue, 17 Jan 2023 09:54:01 -0700 Subject: [PATCH] Graceful shutdown blocks on running proxy handler connections (#2789) --- .../daemon/cluster/handlers/worker_service.go | 2 +- internal/daemon/common/worker_list.go | 10 +++++ internal/daemon/worker/addressreceiver.go | 8 ---- internal/daemon/worker/handler.go | 2 + .../daemon/worker/proxy/proxy_conn_tracker.go | 40 ++++++++++++++++++ .../worker/proxy/proxy_conn_tracker_test.go | 41 +++++++++++++++++++ internal/daemon/worker/worker.go | 38 ++++++++--------- internal/server/options.go | 8 ++++ internal/server/options_test.go | 8 ++++ 9 files changed, 129 insertions(+), 28 deletions(-) create mode 100644 internal/daemon/worker/proxy/proxy_conn_tracker.go create mode 100644 internal/daemon/worker/proxy/proxy_conn_tracker_test.go diff --git a/internal/daemon/cluster/handlers/worker_service.go b/internal/daemon/cluster/handlers/worker_service.go index aa01c83f4c..83d7b8ecc6 100644 --- a/internal/daemon/cluster/handlers/worker_service.go +++ b/internal/daemon/cluster/handlers/worker_service.go @@ -123,7 +123,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques if wStat.OperationalState == "" { // If this is an older worker (pre 0.11), it will not have ReleaseVersion and we'll default to active. - // Otherwise, default to Uknown. + // Otherwise, default to Unknown. if wStat.ReleaseVersion == "" { wStat.OperationalState = server.ActiveOperationalState.String() } else { diff --git a/internal/daemon/common/worker_list.go b/internal/daemon/common/worker_list.go index ab7bbd2990..ff2a923bee 100644 --- a/internal/daemon/common/worker_list.go +++ b/internal/daemon/common/worker_list.go @@ -24,6 +24,16 @@ func (w WorkerList) Addresses() []string { return ret } +// PublicIds converts the slice of workers to a slice of public ids of those +// workers. 
+func (w WorkerList) PublicIds() []string { + ret := make([]string, 0, len(w)) + for _, worker := range w { + ret = append(ret, worker.GetPublicId()) + } + return ret +} + // workerInfos converts the slice of workers to a slice of their workerInfo protos func (w WorkerList) WorkerInfos() []*pb.WorkerInfo { ret := make([]*pb.WorkerInfo, 0, len(w)) diff --git a/internal/daemon/worker/addressreceiver.go b/internal/daemon/worker/addressreceiver.go index 68a4636b69..6444bb596a 100644 --- a/internal/daemon/worker/addressreceiver.go +++ b/internal/daemon/worker/addressreceiver.go @@ -1,18 +1,10 @@ package worker import ( - "context" - "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" ) -var extraAddressReceivers = noopAddressReceivers - -func noopAddressReceivers(context.Context, *Worker) ([]addressReceiver, error) { - return nil, nil -} - type receiverType uint const ( diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 23a256e18b..7f652c38be 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -48,10 +48,12 @@ func (w *Worker) handler(props HandlerProperties, sm session.Manager) (http.Hand // Create the muxer to handle the actual endpoints mux := http.NewServeMux() + var h http.Handler h, err := w.handleProxy(props.ListenerConfig, sm) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } + h = proxyHandlers.ProxyHandlerCounter(h) mux.Handle("/v1/proxy", metric.InstrumentWebsocketWrapper(h)) genericWrappedHandler := w.wrapGenericHandler(mux, props) diff --git a/internal/daemon/worker/proxy/proxy_conn_tracker.go b/internal/daemon/worker/proxy/proxy_conn_tracker.go new file mode 100644 index 0000000000..94a7055166 --- /dev/null +++ b/internal/daemon/worker/proxy/proxy_conn_tracker.go @@ -0,0 +1,40 @@ +package proxy + +import ( + "context" + "net/http" + "sync/atomic" +) + +// ProxyState contains the current state of proxies in this process. 
+var ProxyState proxyState + +type proxyState struct { + proxyCount atomic.Int64 +} + +// CurrentProxiedConnections returns the current number of ongoing proxied +// connections which are currently running the Handler's ProxyConnFn. +func (p *proxyState) CurrentProxiedConnections() int64 { + return p.proxyCount.Load() +} + +// ProxyHandlerCounter records how many requests are currently running in the +// wrapped Handler. This should be used for handlers that serve proxied traffic. +func ProxyHandlerCounter(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ProxyState.proxyCount.Add(1) + defer ProxyState.proxyCount.Add(-1) + h.ServeHTTP(rw, r) + }) +} + +// proxyConnFnCounter wraps a ProxyConnFn and keeps the proxyCount incremented +// while it runs. +func proxyConnFnCounter(fn ProxyConnFn) ProxyConnFn { + return func(ctx context.Context) { + ProxyState.proxyCount.Add(1) + defer ProxyState.proxyCount.Add(-1) + fn(ctx) + } +} diff --git a/internal/daemon/worker/proxy/proxy_conn_tracker_test.go b/internal/daemon/worker/proxy/proxy_conn_tracker_test.go new file mode 100644 index 0000000000..dadc68e2df --- /dev/null +++ b/internal/daemon/worker/proxy/proxy_conn_tracker_test.go @@ -0,0 +1,41 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProxyStateHelpers(t *testing.T) { + old := ProxyState + ProxyState = proxyState{} + t.Cleanup(func() { + ProxyState = old + }) + + t.Run("proxyConnFnCounter", func(t *testing.T) { + assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections()) + proxyConnFnCounter(func(context.Context) { + assert.EqualValues(t, 1, ProxyState.CurrentProxiedConnections()) + })(context.Background()) + assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections()) + }) + + t.Run("ProxyHandlerCounter", func(t *testing.T) { + assert.EqualValues(t, 0, 
ProxyState.CurrentProxiedConnections()) + var handlerRan bool + h := ProxyHandlerCounter(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + handlerRan = true + assert.EqualValues(t, 1, ProxyState.CurrentProxiedConnections()) + })) + assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections()) + req, err := http.NewRequest("GET", "test/path", nil) + require.NoError(t, err) + h.ServeHTTP(httptest.NewRecorder(), req) + require.True(t, handlerRan) + }) +} diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index 0bcef2a936..2c08855a79 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/boundary/internal/daemon/cluster" "github.com/hashicorp/boundary/internal/daemon/worker/common" "github.com/hashicorp/boundary/internal/daemon/worker/internal/metric" + "github.com/hashicorp/boundary/internal/daemon/worker/proxy" "github.com/hashicorp/boundary/internal/daemon/worker/session" "github.com/hashicorp/boundary/internal/errors" pb "github.com/hashicorp/boundary/internal/gen/controller/servers" @@ -469,31 +470,30 @@ func (w *Worker) Start() error { return nil } -func (w *Worker) hasActiveConnection() bool { - activeConnection := false - w.sessionManager.ForEachLocalSession( - func(s session.Session) bool { - conns := s.GetLocalConnections() - for _, v := range conns { - if v.Status == pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED { - activeConnection = true - return false - } - } - return true - }) - return activeConnection -} - -// Graceful shutdown sets the worker state to "shutdown" and will wait to return until there +// GracefulShutdown sets the worker state to "shutdown" and will wait to return until there // are no longer any active connections. 
func (w *Worker) GracefulShutdown() error { const op = "worker.(Worker).GracefulShutdown" event.WriteSysEvent(w.baseContext, op, "worker entering graceful shutdown") w.operationalState.Store(server.ShutdownOperationalState) - // Wait for connections to drain - for w.hasActiveConnection() { + // As long as some status has been sent in the past, wait for 2 status + // updates to be sent since we've updated our operational state. + lastStatusTime := w.lastSuccessfulStatusTime() + if lastStatusTime != w.workerStartTime { + for i := 0; i < 2; i++ { + for { + if lastStatusTime != w.lastSuccessfulStatusTime() { + lastStatusTime = w.lastSuccessfulStatusTime() + break + } + time.Sleep(time.Millisecond * 250) + } + } + } + + // Wait for running proxy connections to drain + for proxy.ProxyState.CurrentProxiedConnections() > 0 { time.Sleep(time.Millisecond * 250) } event.WriteSysEvent(w.baseContext, op, "worker connections have drained") diff --git a/internal/server/options.go b/internal/server/options.go index 391c499d4d..2bc67c71b1 100644 --- a/internal/server/options.go +++ b/internal/server/options.go @@ -48,6 +48,7 @@ type options struct { withActiveWorkers bool withFeature version.Feature withDirectlyConnected bool + withWorkerPool []string } func getDefaultOptions() options { @@ -238,3 +239,10 @@ func WithDirectlyConnected(conn bool) Option { o.withDirectlyConnected = conn } } + +// WithWorkerPool provides a slice of worker ids. 
+func WithWorkerPool(workerIds []string) Option { + return func(o *options) { + o.withWorkerPool = workerIds + } +} diff --git a/internal/server/options_test.go b/internal/server/options_test.go index d816e5758a..d88efecff0 100644 --- a/internal/server/options_test.go +++ b/internal/server/options_test.go @@ -225,4 +225,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withNewIdFunc = nil assert.Equal(t, opts, testOpts) }) + t.Run("WithWorkerPool", func(t *testing.T) { + opts := GetOpts(WithWorkerPool([]string{"1", "2", "3"})) + testOpts := getDefaultOptions() + testOpts.withWorkerPool = []string{"1", "2", "3"} + opts.withNewIdFunc = nil + testOpts.withNewIdFunc = nil + assert.Equal(t, opts, testOpts) + }) }