From fb1ea92a30a2888fbfb7c36ff30ed5f51e738e23 Mon Sep 17 00:00:00 2001 From: Todd Date: Tue, 30 May 2023 10:23:23 -0700 Subject: [PATCH] Pass both the control and data context into handle --- internal/daemon/worker/handler.go | 4 ++-- internal/daemon/worker/proxy/proxy.go | 4 ++-- internal/daemon/worker/proxy/proxy_conn_tracker.go | 5 ++--- .../daemon/worker/proxy/proxy_conn_tracker_test.go | 5 ++--- internal/daemon/worker/proxy/proxy_test.go | 4 ++-- internal/daemon/worker/proxy/tcp/tcp.go | 12 ++++++------ internal/daemon/worker/proxy/tcp/tcp_test.go | 7 ++++--- 7 files changed, 20 insertions(+), 21 deletions(-) diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index de3ef4829e..580328c414 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -294,7 +294,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function") event.WriteError(ctx, op, err) } - runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager) + runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying") event.WriteError(ctx, op, err) @@ -321,7 +321,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa return } - runProxy(ctx) + runProxy() }, nil } diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index 428c5683bd..c42fa7be44 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -39,14 +39,14 @@ type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error // ProxyConnFn is called after the call to ConnectConnection on the cluster. // ProxyConnFn blocks until the specific request that is being proxied is finished -type ProxyConnFn func(ctx context.Context) +type ProxyConnFn func() // Handler is the type that all proxies need to implement to be called by the worker // when a new client connection is created. If there is an error ProxyConnFn must // be nil. If there is no error ProxyConnFn must be set. When Handler has // returned, it is expected that the initial connection to the endpoint has been // established. -type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) +type Handler func(controlCtx context.Context, dataCtx context.Context, df DecryptFn, c net.Conn, pd *ProxyDialer, connId string, pb *anypb.Any, rm RecordingManager) (ProxyConnFn, error) func RegisterHandler(protocol string, handler Handler) error { _, loaded := handlers.LoadOrStore(protocol, handler) diff --git a/internal/daemon/worker/proxy/proxy_conn_tracker.go b/internal/daemon/worker/proxy/proxy_conn_tracker.go index a52b4c44bf..cb8c1d02a7 100644 --- a/internal/daemon/worker/proxy/proxy_conn_tracker.go +++ b/internal/daemon/worker/proxy/proxy_conn_tracker.go @@ -4,7 +4,6 @@ package proxy import ( - "context" "net/http" "sync/atomic" ) @@ -35,9 +34,9 @@ func ProxyHandlerCounter(h http.Handler) http.Handler { // proxyConnFnCounter wraps a ProxyState and keeps the proxyCount incremented // while it runs. func proxyConnFnCounter(fn ProxyConnFn) ProxyConnFn { - return func(ctx context.Context) { + return func() { ProxyState.proxyCount.Add(1) defer ProxyState.proxyCount.Add(-1) - fn(ctx) + fn() } } diff --git a/internal/daemon/worker/proxy/proxy_conn_tracker_test.go b/internal/daemon/worker/proxy/proxy_conn_tracker_test.go index ce4bdcb700..df099b4da6 100644 --- a/internal/daemon/worker/proxy/proxy_conn_tracker_test.go +++ b/internal/daemon/worker/proxy/proxy_conn_tracker_test.go @@ -4,7 +4,6 @@ package proxy import ( - "context" "net/http" "net/http/httptest" "testing" @@ -22,9 +21,9 @@ func TestProxyStateHelpers(t *testing.T) { t.Run("proxyConnFnCounter", func(t *testing.T) { assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections()) - proxyConnFnCounter(func(context.Context) { + proxyConnFnCounter(func() { assert.EqualValues(t, 1, ProxyState.CurrentProxiedConnections()) - })(context.Background()) + })() assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections()) }) diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index e0b33d059d..89b31ae479 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -17,7 +17,7 @@ import ( func TestRegisterHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { + fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers @@ -40,7 +40,7 @@ func TestRegisterHandler(t *testing.T) { func TestAlwaysTcpGetHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { + fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index c4f0a17737..35a3bd589f 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -27,22 +27,22 @@ func init() { // handleProxy returns a ProxyConnFn which starts the copy between the // connections and blocks until an error (EOF on happy path) is received on // either connection. -func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) { +func handleProxy(controlCtx context.Context, _ context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) { const op = "tcp.HandleProxy" switch { case conn == nil: - return nil, errors.New(ctx, errors.InvalidParameter, op, "conn is nil") + return nil, errors.New(controlCtx, errors.InvalidParameter, op, "conn is nil") case out == nil: - return nil, errors.New(ctx, errors.InvalidParameter, op, "proxy dialer is nil") + return nil, errors.New(controlCtx, errors.InvalidParameter, op, "proxy dialer is nil") case len(connId) == 0: - return nil, errors.New(ctx, errors.InvalidParameter, op, "connection id is empty") + return nil, errors.New(controlCtx, errors.InvalidParameter, op, "connection id is empty") } - remoteConn, err := out.Dial(ctx) + remoteConn, err := out.Dial(controlCtx) if err != nil { return nil, err } - return func(ctx context.Context) { + return func() { connWg := new(sync.WaitGroup) connWg.Add(2) go func() { diff --git a/internal/daemon/worker/proxy/tcp/tcp_test.go b/internal/daemon/worker/proxy/tcp/tcp_test.go index 1a819e5d75..15c77db038 100644 --- a/internal/daemon/worker/proxy/tcp/tcp_test.go +++ b/internal/daemon/worker/proxy/tcp/tcp_test.go @@ -88,7 +88,8 @@ func TestHandleProxy_Errors(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil) + ctx := context.Background() + fn, err := handleProxy(ctx, ctx, nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil) if tc.wantError { assert.Error(t, err) assert.Nil(t, fn) @@ -166,14 +167,14 @@ func TestHandleTcpProxyV1(t *testing.T) { conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary) go func() { - fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil) + fn, err := handleProxy(ctx, context.Background(), nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil) t.Cleanup(func() { // Use of the t.Cleanup is so we can check the state of the returned // error since it isn't valid to call `t.FailNow()` from a goroutine. // https://pkg.go.dev/testing#T.FailNow require.NoError(err) }) - fn(context.Background()) + fn() }() // wait for HandleTcpProxyV1 to dial endpoint