From eaf31ef17a0854d41d7d200cfd0a8374e60100ae Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Mon, 11 Aug 2025 11:34:55 -0400 Subject: [PATCH] feat(rdp): perform basic settings exchange and test the proxy fn (#1657) ----------- Co-authored-by: Danielle Miu <29378233+DanielleMiu@users.noreply.github.com> --- internal/daemon/worker/handler.go | 2 +- internal/daemon/worker/proxy/options.go | 26 ++++++++++++++++++++ internal/daemon/worker/proxy/options_test.go | 18 ++++++++++++++ internal/daemon/worker/proxy/proxy.go | 2 +- internal/daemon/worker/proxy/proxy_test.go | 4 +-- internal/daemon/worker/proxy/tcp/tcp.go | 2 +- 6 files changed, 49 insertions(+), 5 deletions(-) diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index d573f26065..a0d0691166 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -280,7 +280,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function") event.WriteError(ctx, op, err) } - runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager) + runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager, proxyHandlers.WithLogger(w.logger)) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying") event.WriteError(ctx, op, err) diff --git a/internal/daemon/worker/proxy/options.go b/internal/daemon/worker/proxy/options.go index 3c01f23061..f029fa19c1 100644 --- a/internal/daemon/worker/proxy/options.go +++ b/internal/daemon/worker/proxy/options.go @@ -7,6 +7,7 @@ import ( "net" serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + "github.com/hashicorp/go-hclog" ) // Option - how Options are passed as arguments. @@ -26,12 +27,16 @@ type Options struct { WithInjectedApplicationCredentials []*serverpb.Credential WithPostConnectionHook func(net.Conn) WithDnsServerAddress string + WithTestKdcAddress string + WithTestKerberosServerHostname string + WithLogger hclog.Logger } func getDefaultOptions() Options { return Options{ WithInjectedApplicationCredentials: nil, WithPostConnectionHook: nil, + WithLogger: hclog.NewNullLogger(), } } @@ -61,3 +66,24 @@ func WithDnsServerAddress(with string) Option { o.WithDnsServerAddress = with } } + +// WithTestKdcAddress allows specifying a test KDC address to use for testing +func WithTestKdcAddress(with string) Option { + return func(o *Options) { + o.WithTestKdcAddress = with + } +} + +// WithTestKerberosServerHostname allows specifying a test Kerberos server +func WithTestKerberosServerHostname(with string) Option { + return func(o *Options) { + o.WithTestKerberosServerHostname = with + } +} + +// WithLogger allows specifying a logger to be used during session proxy +func WithLogger(l hclog.Logger) Option { + return func(o *Options) { + o.WithLogger = l + } +} diff --git a/internal/daemon/worker/proxy/options_test.go b/internal/daemon/worker/proxy/options_test.go index 7514c3a92d..b547b08c2a 100644 --- a/internal/daemon/worker/proxy/options_test.go +++ b/internal/daemon/worker/proxy/options_test.go @@ -45,4 +45,22 @@ func Test_GetOpts(t *testing.T) { runtime.FuncForPC(reflect.ValueOf(testOpts.WithPostConnectionHook).Pointer()).Name(), ) }) + t.Run("WithTestKdcAdress", func(t *testing.T) { + assert := assert.New(t) + testKdcAddress := "test-kdc-address" + opts := GetOpts(WithTestKdcAddress(testKdcAddress)) + testOpts := getDefaultOptions() + assert.NotEqual(opts, testOpts) + testOpts.WithTestKdcAddress = testKdcAddress + assert.Equal(opts, testOpts) + }) + t.Run("WithTestKerberosServerHostname", func(t *testing.T) { + assert := assert.New(t) + testKerberosServerHostname := "test-kerberos-server-hostname" + opts := GetOpts(WithTestKerberosServerHostname(testKerberosServerHostname)) + testOpts := getDefaultOptions() + assert.NotEqual(opts, testOpts) + testOpts.WithTestKerberosServerHostname = testKerberosServerHostname + assert.Equal(opts, testOpts) + }) } diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index 77dbca71c4..6d2ae55433 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -46,7 +46,7 @@ type ProxyConnFn func() // be nil. If there is no error ProxyConnFn must be set. When Handler has // returned, it is expected that the initial connection to the endpoint has been // established. -type Handler func(controlCtx context.Context, dataCtx context.Context, df DecryptFn, c net.Conn, pd *ProxyDialer, connId string, pb *anypb.Any, rm RecordingManager) (ProxyConnFn, error) +type Handler func(controlCtx context.Context, dataCtx context.Context, df DecryptFn, c net.Conn, pd *ProxyDialer, connId string, pb *anypb.Any, rm RecordingManager, opt ...Option) (ProxyConnFn, error) func RegisterHandler(protocol string, handler Handler) error { _, loaded := handlers.LoadOrStore(protocol, handler) diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index 5de35018fa..891a23c643 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -17,7 +17,7 @@ import ( func TestRegisterHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { + fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager, ...Option) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers @@ -40,7 +40,7 @@ func TestRegisterHandler(t *testing.T) { func TestAlwaysTcpGetHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) { + fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager, ...Option) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index 3a4ba22a2f..414c9263dc 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -27,7 +27,7 @@ func init() { // handleProxy returns a ProxyConnFn which starts the copy between the // connections and blocks until an error (EOF on happy path) is received on // either connection. -func handleProxy(controlCtx context.Context, _ context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) { +func handleProxy(controlCtx context.Context, _ context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager, _ ...proxy.Option) (proxy.ProxyConnFn, error) { const op = "tcp.HandleProxy" switch { case conn == nil: