From bd2bca987d481e6c2c6b53b833c6b275aa424628 Mon Sep 17 00:00:00 2001 From: Todd Date: Tue, 28 Mar 2023 14:24:05 -0700 Subject: [PATCH] Add the first iteration of the recorder cache and wire it in --- internal/daemon/worker/handler.go | 2 +- internal/daemon/worker/proxy/proxy.go | 5 ++++- internal/daemon/worker/proxy/proxy_test.go | 4 ++-- internal/daemon/worker/proxy/tcp/tcp.go | 2 +- internal/daemon/worker/proxy/tcp/tcp_test.go | 4 ++-- internal/daemon/worker/worker.go | 12 ++++++++++++ internal/errors/code.go | 1 + 7 files changed, 23 insertions(+), 7 deletions(-) diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index b68b51a7eb..e0d0af984f 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -294,7 +294,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function") event.WriteError(ctx, op, err) } - runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx) + runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderCache) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying") event.WriteError(ctx, op, err) diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index 5b698a98e3..428c5683bd 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -31,6 +31,9 @@ var ( GetHandler = tcpOnly ) +// RecordingManager allows a handler for a protocol that supports recording. +type RecordingManager any + // DecryptFn decrypts the provided bytes into a proto.Message type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error @@ -43,7 +46,7 @@ type ProxyConnFn func(ctx context.Context) // be nil. If there is no error ProxyConnFn must be set. When Handler has // returned, it is expected that the initial connection to the endpoint has been // established. -type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) +type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) func RegisterHandler(protocol string, handler Handler) error { _, loaded := handlers.LoadOrStore(protocol, handler) diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index 84d0d04fee..e0b33d059d 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -17,7 +17,7 @@ import ( func TestRegisterHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { + fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers @@ -40,7 +40,7 @@ func TestRegisterHandler(t *testing.T) { func TestAlwaysTcpGetHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { + fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index 81e638123f..c4f0a17737 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -27,7 +27,7 @@ func init() { // handleProxy returns a ProxyConnFn which starts the copy between the // connections and blocks until an error (EOF on happy path) is received on // either connection. -func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) { +func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) { const op = "tcp.HandleProxy" switch { case conn == nil: diff --git a/internal/daemon/worker/proxy/tcp/tcp_test.go b/internal/daemon/worker/proxy/tcp/tcp_test.go index 07340cd7bf..1a819e5d75 100644 --- a/internal/daemon/worker/proxy/tcp/tcp_test.go +++ b/internal/daemon/worker/proxy/tcp/tcp_test.go @@ -88,7 +88,7 @@ func TestHandleProxy_Errors(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx) + fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil) if tc.wantError { assert.Error(t, err) assert.Nil(t, fn) @@ -166,7 +166,7 @@ func TestHandleTcpProxyV1(t *testing.T) { conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary) go func() { - fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext()) + fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil) t.Cleanup(func() { // Use of the t.Cleanup is so we can check the state of the returned // error since it isn't valid to call `t.FailNow()` from a goroutine. diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index 7913491552..8e2a1967a5 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -69,10 +69,16 @@ type downstreamers interface { RootId() string } +// recorderCache updates the status updates with relevant recording +// information +type recorderCache any + // reverseConnReceiverFactory provides a simple factory which a Worker can use to // create its reverseConnReceiver var reverseConnReceiverFactory func() reverseConnReceiver +var recorderCacheFactory func() recorderCache + var initializeReverseGrpcClientCollectors = noopInitializePromCollectors func noopInitializePromCollectors(r prometheus.Registerer) {} @@ -103,6 +109,8 @@ type Worker struct { sessionManager session.Manager + recorderCache recorderCache + controllerStatusConn *atomic.Value everAuthenticated *ua.Uint32 lastStatusSuccess *atomic.Value @@ -182,6 +190,10 @@ func New(conf *Config) (*Worker, error) { w.downstreamReceiver = reverseConnReceiverFactory() } + if recorderCacheFactory != nil { + w.recorderCache = recorderCacheFactory() + } + w.lastStatusSuccess.Store((*LastStatusInformation)(nil)) scheme := strconv.FormatInt(time.Now().UnixNano(), 36) controllerResolver := manual.NewBuilderWithScheme(scheme) diff --git a/internal/errors/code.go b/internal/errors/code.go index 8a557384e5..e93204ce61 100644 --- a/internal/errors/code.go +++ b/internal/errors/code.go @@ -84,6 +84,7 @@ const ( // client and server error codes Unauthorized Code = 401 // Unauthorized represents the operation is unauthorized Forbidden Code = 403 // Forbidden represents the operation is forbidden + NotFound Code = 404 // NotFound represents an operation which is unable to find the requested item. Conflict Code = 409 // Conflict represents the operation failed due to failed pre-condition or was aborted. Internal Code = 500 // InternalError represents the system encountered an unexpected condition.