diff --git a/internal/servers/worker/handler.go b/internal/servers/worker/handler.go index e02f5d249b..1d74f3ff9c 100644 --- a/internal/servers/worker/handler.go +++ b/internal/servers/worker/handler.go @@ -232,7 +232,12 @@ func (w *Worker) handleProxy() http.HandlerFunc { return } - handleProxyFn(connCtx, conf) + if err = handleProxyFn(connCtx, conf); err != nil { + event.WriteError(ctx, op, err, event.WithInfoMsg("error handling proxy", "session_id", sessionId, "endpoint", endpoint)) + if err = conn.Close(websocket.StatusInternalError, "unable to establish proxy"); err != nil { + event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) + } + } } } diff --git a/internal/servers/worker/proxy/proxy.go b/internal/servers/worker/proxy/proxy.go index 4fff22c82b..4c2e074f94 100644 --- a/internal/servers/worker/proxy/proxy.go +++ b/internal/servers/worker/proxy/proxy.go @@ -47,7 +47,7 @@ func (c Config) Validate() error { // Handler is the type that all proxies need to implement to be called by the worker // when a new client connection is created. -type Handler func(ctx context.Context, config Config, opt ...Option) +type Handler func(ctx context.Context, config Config, opt ...Option) error var ( // handlers is the map of registered handlers diff --git a/internal/servers/worker/proxy/proxy_test.go b/internal/servers/worker/proxy/proxy_test.go index 100e3fcbc3..42081cdea9 100644 --- a/internal/servers/worker/proxy/proxy_test.go +++ b/internal/servers/worker/proxy/proxy_test.go @@ -136,7 +136,7 @@ func TestRegisterHandler(t *testing.T) { t.Parallel() assert, require := assert.New(t), require.New(t) - fn := func(context.Context, Config, ...Option) { return } + fn := func(context.Context, Config, ...Option) error { return nil } err := RegisterHandler("protocol", fn) require.NoError(err) @@ -154,7 +154,7 @@ func TestGetHandler(t *testing.T) { t.Parallel() assert, require := assert.New(t), require.New(t) - fn := func(context.Context, Config, ...Option) {} + fn := func(context.Context, Config, ...Option) error { return nil } err := RegisterHandler("fn", fn) require.NoError(err) diff --git a/internal/servers/worker/proxy/tcp/tcp.go b/internal/servers/worker/proxy/tcp/tcp.go index b07e779103..57f3c71edc 100644 --- a/internal/servers/worker/proxy/tcp/tcp.go +++ b/internal/servers/worker/proxy/tcp/tcp.go @@ -2,63 +2,45 @@ package tcp import ( "context" + "fmt" "io" "net" "net/url" "sync" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" - "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/servers/worker/proxy" "github.com/hashicorp/boundary/internal/servers/worker/session" "nhooyr.io/websocket" ) func init() { - err := proxy.RegisterHandler("tcp", handleTcpProxyV1) + err := proxy.RegisterHandler("tcp", handleProxy) if err != nil { panic(err) } } -// handleTcpProxyV1 creates a tcp proxy between the incoming websocket conn and the +// handleProxy creates a tcp proxy between the incoming websocket conn and the // connection it creates with the remote endpoint. handleTcpProxyV1 sets the connectionId // as connected in the repository. // -// handleTcpProxyV1 blocks until an error (EOF on happy path) is received on either +// handleProxy blocks until an error (EOF on happy path) is received on either // connection. // // All options are ignored. -func handleTcpProxyV1(ctx context.Context, conf proxy.Config, _ ...proxy.Option) { - const op = "tcp.HandleTcpProxyV1" - si := conf.SessionInfo - si.RLock() - sessionId := si.LookupSessionResponse.GetAuthorization().GetSessionId() - si.RUnlock() - +func handleProxy(ctx context.Context, conf proxy.Config, _ ...proxy.Option) error { conn := conf.ClientConn sessionUrl, err := url.Parse(conf.RemoteEndpoint) if err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error parsing endpoint information", "session_id", sessionId, "endpoint", conf.RemoteEndpoint)) - if err = conn.Close(websocket.StatusInternalError, "cannot parse endpoint url"); err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) - } - return + return fmt.Errorf("error parsing endpoint information: %w", err) } if sessionUrl.Scheme != "tcp" { - event.WriteError(ctx, op, err, event.WithInfo("session_id", sessionId, "endpoint", conf.RemoteEndpoint)) - if err = conn.Close(websocket.StatusInternalError, "invalid scheme for type"); err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) - } - return + return fmt.Errorf("invalid scheme for tcp proxy: %v", sessionUrl.Scheme) } remoteConn, err := net.Dial("tcp", sessionUrl.Host) if err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error dialing endpoint", "endpoint", conf.RemoteEndpoint)) - if err = conn.Close(websocket.StatusInternalError, "endpoint dialing failed"); err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) - } - return + return fmt.Errorf("error dialing endpoint: %w", err) } // Assert this for better Go 1.11 splice support tcpRemoteConn := remoteConn.(*net.TCPConn) @@ -75,15 +57,13 @@ func handleTcpProxyV1(ctx context.Context, conf proxy.Config, _ ...proxy.Option) connStatus, err := session.ConnectConnection(ctx, conf.SessionClient, connectionInfo) if err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error marking connection as connected")) - if err = conn.Close(websocket.StatusInternalError, "failed to mark connection as connected"); err != nil { - event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) - } - return + return fmt.Errorf("error marking connection as connected: %w", err) } - si.Lock() - si.ConnInfoMap[conf.ConnectionId].Status = connStatus - si.Unlock() + + // Update connection info to set connection status + conf.SessionInfo.Lock() + conf.SessionInfo.ConnInfoMap[conf.ConnectionId].Status = connStatus + conf.SessionInfo.Unlock() // Get a wrapped net.Conn so we can use io.Copy netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) @@ -103,4 +83,5 @@ func handleTcpProxyV1(ctx context.Context, conf proxy.Config, _ ...proxy.Option) _ = netConn.Close() }() connWg.Wait() + return nil } diff --git a/internal/servers/worker/proxy/tcp/tcp_test.go b/internal/servers/worker/proxy/tcp/tcp_test.go index 457be83b37..0a365ba6b1 100644 --- a/internal/servers/worker/proxy/tcp/tcp_test.go +++ b/internal/servers/worker/proxy/tcp/tcp_test.go @@ -70,7 +70,10 @@ func TestHandleTcpProxyV1(t *testing.T) { ConnectionId: "mock-connection", } - go handleTcpProxyV1(ctx, conf) + go func() { + err = handleProxy(ctx, conf) + require.NoError(err) + }() // wait for HandleTcpProxyV1 to dial endpoint <-ready