From beb97ac9f28f8d68a77f03e35bb091eb2a819aed Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Mon, 9 Jan 2023 13:43:02 -0500 Subject: [PATCH] SSH Proxy support (#2736) * initial commit --- internal/daemon/worker/handler.go | 4 ++-- internal/daemon/worker/proxy/options.go | 14 ++++++++++++++ internal/daemon/worker/proxy/options_test.go | 16 ++++++++++++++++ internal/daemon/worker/proxy/proxy.go | 17 ++++++++++++----- internal/daemon/worker/proxy/proxy_test.go | 15 ++++++--------- internal/daemon/worker/proxy/proxydialer.go | 10 +++++----- .../daemon/worker/proxy/proxydialer_test.go | 8 +++++--- internal/daemon/worker/proxy/tcp/tcp.go | 16 +++++++--------- internal/daemon/worker/proxy/tcp/tcp_test.go | 6 +++--- 9 files changed, 70 insertions(+), 36 deletions(-) diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 1da3342ccb..1d8f367499 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -273,7 +273,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa } // Verify the protocol has a supported proxy before calling RequestAuthorizeConnection - handleProxyFn, err := proxyHandlers.GetHandler(endpointUrl.Scheme) + handleProxyFn, err := proxyHandlers.GetHandler(workerId, ci.GetProtocolContext()) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to get proxy handler") event.WriteError(ctx, op, err) @@ -306,7 +306,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa return } - runProxy() + runProxy(ctx) }, nil } diff --git a/internal/daemon/worker/proxy/options.go b/internal/daemon/worker/proxy/options.go index cad7adf380..a46813a600 100644 --- a/internal/daemon/worker/proxy/options.go +++ b/internal/daemon/worker/proxy/options.go @@ -1,6 +1,8 @@ package proxy import ( + "net" + serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" ) @@ -19,11 +21,13 @@ func GetOpts(opt ...Option) Options { // Options = how options are represented type Options struct { WithInjectedApplicationCredentials []*serverpb.Credential + WithPostConnectionHook func(net.Conn) } func getDefaultOptions() Options { return Options{ WithInjectedApplicationCredentials: nil, + WithPostConnectionHook: nil, } } @@ -34,3 +38,13 @@ func WithInjectedApplicationCredentials(creds []*serverpb.Credential) Option { o.WithInjectedApplicationCredentials = creds } } + +// WithPostConnectionHook provides a hook function to be called after a +// connection is established in a dialFunction. When a dialer accepts +// WithPostConnectionHook the passed in function should be called prior to any +// other blocking call. +func WithPostConnectionHook(fn func(net.Conn)) Option { + return func(o *Options) { + o.WithPostConnectionHook = fn + } +} diff --git a/internal/daemon/worker/proxy/options_test.go b/internal/daemon/worker/proxy/options_test.go index a4733e2af5..160202426d 100644 --- a/internal/daemon/worker/proxy/options_test.go +++ b/internal/daemon/worker/proxy/options_test.go @@ -1,6 +1,9 @@ package proxy import ( + "net" + "reflect" + "runtime" "testing" serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -26,4 +29,17 @@ func Test_GetOpts(t *testing.T) { testOpts.WithInjectedApplicationCredentials = []*serverpb.Credential{c} assert.Equal(opts, testOpts) }) + + t.Run("WithPostConnectionHook", func(t *testing.T) { + assert := assert.New(t) + testFn := func(net.Conn) {} + opts := GetOpts(WithPostConnectionHook(testFn)) + testOpts := getDefaultOptions() + assert.NotEqual(opts, testOpts) + testOpts.WithPostConnectionHook = testFn + assert.Equal( + runtime.FuncForPC(reflect.ValueOf(opts.WithPostConnectionHook).Pointer()).Name(), + runtime.FuncForPC(reflect.ValueOf(testOpts.WithPostConnectionHook).Pointer()).Name(), + ) + }) } diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index ef58f8d469..a208d627a3 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -6,10 +6,13 @@ import ( "net" "sync" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" ) var ( + TcpHandlerName = "tcp" + // handlers is the map of registered handlers handlers sync.Map @@ -18,11 +21,16 @@ var ( // ErrProtocolAlreadyRegistered specifies the provided protocol has already been registered ErrProtocolAlreadyRegistered = errors.New("proxy: protocol already registered") + + // GetHandler returns the handler registered for the provided worker and + // protocolContext. If a protocol cannot be determined or the protocol is + // not registered nil, ErrUnknownProtocol is returned. + GetHandler = tcpOnly ) // ProxyConnFn is called after the call to ConnectConnection on the cluster. // ProxyConnFn blocks until the specific request that is being proxied is finished -type ProxyConnFn func() +type ProxyConnFn func(ctx context.Context) // Handler is the type that all proxies need to implement to be called by the worker // when a new client connection is created. If there is an error ProxyConnFn must @@ -39,10 +47,9 @@ func RegisterHandler(protocol string, handler Handler) error { return nil } -// GetHandler returns the handler registered for the provided protocol. If the protocol -// is not registered nil and ErrUnknownProtocol is returned. -func GetHandler(protocol string) (Handler, error) { - handler, ok := handlers.Load(protocol) +// tcpOnly returns only the TCP protocol. +func tcpOnly(string, proto.Message) (Handler, error) { + handler, ok := handlers.Load(TcpHandlerName) if !ok { return nil, ErrUnknownProtocol } diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index 3f551fb305..04522a327d 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -35,9 +35,8 @@ func TestRegisterHandler(t *testing.T) { require.NoError(err) } -func TestGetHandler(t *testing.T) { +func TestAlwaysTcpGetHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { return nil, nil } @@ -47,15 +46,13 @@ func TestGetHandler(t *testing.T) { }) handlers = sync.Map{} - err := RegisterHandler("fn", fn) - require.NoError(err) - - gotFn, err := GetHandler("fake") + _, err := tcpOnly("wid", nil) require.Error(err) assert.ErrorIs(err, ErrUnknownProtocol) - assert.Nil(gotFn) - gotFn, err = GetHandler("fn") + require.NoError(RegisterHandler("tcp", fn)) + + handler, err := tcpOnly("wid", nil) require.NoError(err) - assert.NotNil(gotFn) + require.NotNil(handler) } diff --git a/internal/daemon/worker/proxy/proxydialer.go b/internal/daemon/worker/proxy/proxydialer.go index fb5899903f..9428a6215d 100644 --- a/internal/daemon/worker/proxy/proxydialer.go +++ b/internal/daemon/worker/proxy/proxydialer.go @@ -20,7 +20,7 @@ func directDialer(ctx context.Context, endpoint string, _ string, _ proto.Messag if len(endpoint) == 0 { return nil, errors.New(ctx, errors.InvalidParameter, op, "endpoint is empty") } - d, err := NewProxyDialer(ctx, func() (net.Conn, error) { + d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) { remoteConn, err := net.Dial("tcp", endpoint) if err != nil { return nil, errors.Wrap(ctx, err, op) @@ -52,12 +52,12 @@ func (p *proxyAddr) Port() uint32 { // ProxyDialer dials downstream to eventually get to the target host. type ProxyDialer struct { - dialFn func() (net.Conn, error) + dialFn func(...Option) (net.Conn, error) latestAddr atomic.Pointer[proxyAddr] } // Returns a new proxy dialer using the provided function to get the net.Conn. -func NewProxyDialer(ctx context.Context, df func() (net.Conn, error)) (*ProxyDialer, error) { +func NewProxyDialer(ctx context.Context, df func(...Option) (net.Conn, error)) (*ProxyDialer, error) { const op = "proxy.NewProxyDialer" if df == nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "dialing function is nil") @@ -86,9 +86,9 @@ type portAndIpGetter interface { // Dial uses the provided dial function to get a net.Conn and record its // net.Addr information. The returned net.Addr should contain the information // for the endpoint that is being proxied to. -func (d *ProxyDialer) Dial(ctx context.Context) (net.Conn, error) { +func (d *ProxyDialer) Dial(ctx context.Context, opt ...Option) (net.Conn, error) { const op = "proxy.(*ProxyDialer).Dial" - c, err := d.dialFn() + c, err := d.dialFn(opt...) if err != nil { return nil, err } diff --git a/internal/daemon/worker/proxy/proxydialer_test.go b/internal/daemon/worker/proxy/proxydialer_test.go index 348cb1ae9e..9822b67959 100644 --- a/internal/daemon/worker/proxy/proxydialer_test.go +++ b/internal/daemon/worker/proxy/proxydialer_test.go @@ -15,7 +15,7 @@ func TestNewProxyDialer(t *testing.T) { assert.Error(t, err) assert.Nil(t, d) - d, err = NewProxyDialer(context.Background(), func() (net.Conn, error) { + d, err = NewProxyDialer(context.Background(), func(...Option) (net.Conn, error) { c, _ := net.Pipe() return c, nil }) @@ -34,9 +34,10 @@ func TestProxyDialer(t *testing.T) { t.Run("Dial error", func(t *testing.T) { expectedErr := errors.New("test error") - d, err := NewProxyDialer(ctx, func() (net.Conn, error) { + d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) { return nil, expectedErr }) + require.NoError(t, err) assert.Nil(t, d.LastConnectionAddr()) badC, err := d.Dial(ctx) require.Error(t, err) @@ -45,9 +46,10 @@ func TestProxyDialer(t *testing.T) { }) t.Run("Successful Dial", func(t *testing.T) { - d, err := NewProxyDialer(ctx, func() (net.Conn, error) { + d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) { return net.Dial("tcp", l.Addr().String()) }) + require.NoError(t, err) assert.Nil(t, d.LastConnectionAddr()) c, err := d.Dial(ctx) require.NoError(t, err) diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index 96117c10ff..f83505d0e8 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -12,20 +12,18 @@ import ( ) func init() { - err := proxy.RegisterHandler("tcp", handleProxy) + err := proxy.RegisterHandler(proxy.TcpHandlerName, handleProxy) if err != nil { panic(err) } } -// handleProxy creates a tcp proxy between the incoming websocket conn and the -// connection it creates with the remote endpoint. handleTcpProxyV1 sets the connectionId -// as connected in the repository. +// handleProxy creates a tcp proxy between the incoming conn and the +// connection created by the ProxyDialer. // -// handleProxy blocks until an error (EOF on happy path) is received on either -// connection. -// -// All options are ignored. +// handleProxy returns a ProxyConnFn which starts the copy between the +// connections and blocks until an error (EOF on happy path) is received on +// either connection. func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) { const op = "tcp.HandleProxy" switch { @@ -41,7 +39,7 @@ func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, con return nil, err } - return func() { + return func(ctx context.Context) { connWg := new(sync.WaitGroup) connWg.Add(2) go func() { diff --git a/internal/daemon/worker/proxy/tcp/tcp_test.go b/internal/daemon/worker/proxy/tcp/tcp_test.go index b25ef634f3..b1fb2b5688 100644 --- a/internal/daemon/worker/proxy/tcp/tcp_test.go +++ b/internal/daemon/worker/proxy/tcp/tcp_test.go @@ -29,7 +29,7 @@ func TestHandleProxy_Errors(t *testing.T) { t.Cleanup(func() { l.Close() }) - dialer, err := proxy.NewProxyDialer(context.Background(), func() (net.Conn, error) { + dialer, err := proxy.NewProxyDialer(context.Background(), func(...proxy.Option) (net.Conn, error) { return net.Dial("tcp", l.Addr().String()) }) require.NoError(t, err) @@ -156,7 +156,7 @@ func TestHandleTcpProxyV1(t *testing.T) { resp, _, err := s.RequestAuthorizeConnection(ctx, "workerid", connCancelFn) require.NoError(err) - tDial, err := proxy.NewProxyDialer(ctx, func() (net.Conn, error) { + tDial, err := proxy.NewProxyDialer(ctx, func(...proxy.Option) (net.Conn, error) { return net.Dial("tcp", l.Addr().String()) }) require.NoError(err) @@ -170,7 +170,7 @@ func TestHandleTcpProxyV1(t *testing.T) { // https://pkg.go.dev/testing#T.FailNow require.NoError(err) }) - fn() + fn(context.Background()) }() // wait for HandleTcpProxyV1 to dial endpoint