diff --git a/internal/daemon/cluster/handlers/worker_service.go b/internal/daemon/cluster/handlers/worker_service.go index fd6c9c6e73..aa01c83f4c 100644 --- a/internal/daemon/cluster/handlers/worker_service.go +++ b/internal/daemon/cluster/handlers/worker_service.go @@ -45,7 +45,10 @@ var ( _ pbs.ServerCoordinationServiceServer = &workerServiceServer{} workerFilterSelectionFn = workerFilterSelector - connectionRouteFn = singleHopConnectionRoute + // connectionRouteFn returns a route to the egress worker. If the requester + // is the egress worker a route of length 1 is returned. A route of + // length 0 is never returned unless there is an error. + connectionRouteFn = singleHopConnectionRoute // getProtocolContext populates the protocol specific context fields // depending on the protocol used to for the boundary connection. Defaults @@ -304,7 +307,7 @@ func workerFilterSelector(sessionInfo *session.Session) string { } // noProtocolContext doesn't provide any protocol context since tcp doesn't need any -func noProtocolContext(context.Context, *session.Repository, *pbs.AuthorizeConnectionRequest) (*anypb.Any, error) { +func noProtocolContext(context.Context, *session.Repository, *server.Repository, common.WorkerAuthRepoStorageFactory, *pbs.AuthorizeConnectionRequest, []string) (*anypb.Any, error) { return nil, nil } @@ -527,7 +530,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs ConnectionsLeft: authzSummary.ConnectionLimit, Route: route, } - if pc, err := getProtocolContext(ctx, sessionRepo, req); err != nil { + if pc, err := getProtocolContext(ctx, sessionRepo, serversRepo, ws.workerAuthRepoFn, req, route); err != nil { return nil, err } else { ret.ProtocolContext = pc diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 21a58172b4..23a256e18b 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -2,7 +2,7 @@ package worker import ( "context" - "errors" + stderrors "errors" "fmt" "io" "net" @@ -16,10 +16,15 @@ import ( "github.com/hashicorp/boundary/internal/daemon/worker/internal/metric" proxyHandlers "github.com/hashicorp/boundary/internal/daemon/worker/proxy" "github.com/hashicorp/boundary/internal/daemon/worker/session" + "github.com/hashicorp/boundary/internal/errors" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/proxy" + "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/go-secure-stdlib/listenerutil" + "github.com/hashicorp/nodeenrollment" + "github.com/hashicorp/nodeenrollment/types" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/timestamppb" "nhooyr.io/websocket" @@ -62,7 +67,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa return func(wr http.ResponseWriter, r *http.Request) { ctx := r.Context() if r.TLS == nil { - event.WriteError(ctx, op, errors.New("no request tls information found")) + event.WriteError(ctx, op, stderrors.New("no request tls information found")) wr.WriteHeader(http.StatusInternalServerError) return } @@ -78,7 +83,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa } } if sessionId == "" { - event.WriteError(ctx, op, errors.New("no session id could be found in peer certificates")) + event.WriteError(ctx, op, stderrors.New("no session id could be found in peer certificates")) wr.WriteHeader(http.StatusInternalServerError) return } @@ -108,7 +113,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa sess := sessionManager.Get(sessionId) if sess == nil { - event.WriteError(ctx, op, errors.New("session not found locally"), event.WithInfo("session_id", sessionId)) + event.WriteError(ctx, op, stderrors.New("session not found locally"), event.WithInfo("session_id", sessionId)) wr.WriteHeader(http.StatusInternalServerError) return } @@ -137,7 +142,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa return } if len(handshake.GetTofuToken()) < 20 { - event.WriteError(ctx, op, errors.New("invalid tofu token")) + event.WriteError(ctx, op, stderrors.New("invalid tofu token")) if err = conn.Close(websocket.StatusUnsupportedData, "invalid tofu token"); err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } @@ -147,12 +152,12 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa if handshake.Command == proxy.HANDSHAKECOMMAND_HANDSHAKECOMMAND_SESSION_CANCEL { if err := sess.RequestCancel(ctx); err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("unable to cancel session")) - if err = conn.Close(websocket.StatusInternalError, "unable to cancel session"); err != nil && !errors.Is(err, io.EOF) { + if err = conn.Close(websocket.StatusInternalError, "unable to cancel session"); err != nil && !stderrors.Is(err, io.EOF) { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } return } - if err = conn.Close(websocket.StatusNormalClosure, "session canceled"); err != nil && !errors.Is(err, io.EOF) { + if err = conn.Close(websocket.StatusNormalClosure, "session canceled"); err != nil && !stderrors.Is(err, io.EOF) { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } return @@ -160,7 +165,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa if sess.GetTofuToken() != "" { if sess.GetTofuToken() != handshake.GetTofuToken() { - event.WriteError(ctx, op, errors.New("WARNING: mismatched tofu token"), event.WithInfo("session_id", sessionId)) + event.WriteError(ctx, op, stderrors.New("WARNING: mismatched tofu token"), event.WithInfo("session_id", sessionId)) if err = conn.Close(websocket.StatusPolicyViolation, "tofu token not allowed"); err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } @@ -168,7 +173,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa } } else { if sess.GetStatus() != pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING { - event.WriteError(ctx, op, errors.New("no tofu token but not in correct session state"), event.WithInfo("session_id", sessionId)) + event.WriteError(ctx, op, stderrors.New("no tofu token but not in correct session state"), event.WithInfo("session_id", sessionId)) if err = conn.Close(websocket.StatusInternalError, "refusing to activate session"); err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } @@ -188,7 +193,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa } if w.LastStatusSuccess() == nil || w.LastStatusSuccess().WorkerId == "" { - event.WriteError(ctx, op, errors.New("worker id is empty")) + event.WriteError(ctx, op, stderrors.New("worker id is empty")) if err = conn.Close(websocket.StatusInternalError, "worker id is empty"); err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection")) } @@ -279,7 +284,12 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa event.WriteError(ctx, op, err) return } - runProxy, err := handleProxyFn(ctx, cc, pDialer, acResp.GetConnectionId(), protocolCtx) + decryptFn, err := w.credDecryptFn(ctx) + if err != nil { + conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function") + event.WriteError(ctx, op, err) + } + runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying") event.WriteError(ctx, op, err) @@ -310,6 +320,32 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa }, nil } +// credDecryptFn returns a DecryptFn if the worker is a pki worker with +// WorkerAuthStorage defined. An error is returned if there is an error +// loading the node credentials. +func (w *Worker) credDecryptFn(ctx context.Context) (proxyHandlers.DecryptFn, error) { + const op = "worker.(*Worker).credDecryptFn" + if w.WorkerAuthStorage == nil { + return nil, nil + } + was := w.WorkerAuthStorage + var opts []nodeenrollment.Option + if !util.IsNil(w.conf.WorkerAuthStorageKms) { + opts = append(opts, nodeenrollment.WithWrapper(w.conf.WorkerAuthStorageKms)) + } + nodeCreds, err := types.LoadNodeCredentials(ctx, was, nodeenrollment.CurrentId, opts...) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + return func(ctx context.Context, from []byte, to proto.Message) error { + if err := nodeenrollment.DecryptMessage(ctx, from, nodeCreds, to); err != nil { + return errors.Wrap(ctx, err, op) + } + return nil + }, nil +} + func (w *Worker) wrapGenericHandler(h http.Handler, _ HandlerProperties) http.Handler { return http.HandlerFunc(func(wr http.ResponseWriter, r *http.Request) { // Set the Cache-Control header for all responses returned diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index a208d627a3..4828807da5 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -28,6 +28,9 @@ var ( GetHandler = tcpOnly ) +// DecryptFn decrypts the provided bytes into a proto.Message +type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error + // ProxyConnFn is called after the call to ConnectConnection on the cluster. // ProxyConnFn blocks until the specific request that is being proxied is finished type ProxyConnFn func(ctx context.Context) @@ -37,7 +40,7 @@ type ProxyConnFn func(ctx context.Context) // be nil. If there is no error ProxyConnFn must be set. When Handler has // returned, it is expected that the initial connection to the endpoint has been // established. -type Handler func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) +type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) func RegisterHandler(protocol string, handler Handler) error { _, loaded := handlers.LoadOrStore(protocol, handler) diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index 04522a327d..876e07648c 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -14,7 +14,7 @@ import ( func TestRegisterHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { + fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers @@ -37,7 +37,7 @@ func TestRegisterHandler(t *testing.T) { func TestAlwaysTcpGetHandler(t *testing.T) { assert, require := assert.New(t), require.New(t) - fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { + fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) { return nil, nil } oldHandler := handlers @@ -45,9 +45,7 @@ func TestAlwaysTcpGetHandler(t *testing.T) { handlers = oldHandler }) handlers = sync.Map{} - _, err := tcpOnly("wid", nil) - require.Error(err) assert.ErrorIs(err, ErrUnknownProtocol) require.NoError(RegisterHandler("tcp", fn)) diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index f83505d0e8..02099789a0 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -24,7 +24,7 @@ func init() { // handleProxy returns a ProxyConnFn which starts the copy between the // connections and blocks until an error (EOF on happy path) is received on // either connection. -func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) { +func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) { const op = "tcp.HandleProxy" switch { case conn == nil: diff --git a/internal/daemon/worker/proxy/tcp/tcp_test.go b/internal/daemon/worker/proxy/tcp/tcp_test.go index b1fb2b5688..4db2c731d8 100644 --- a/internal/daemon/worker/proxy/tcp/tcp_test.go +++ b/internal/daemon/worker/proxy/tcp/tcp_test.go @@ -85,7 +85,7 @@ func TestHandleProxy_Errors(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - fn, err := handleProxy(context.Background(), tc.conn, tc.dialer, tc.connId, tc.protocolCtx) + fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx) if tc.wantError { assert.Error(t, err) assert.Nil(t, fn) @@ -163,7 +163,7 @@ func TestHandleTcpProxyV1(t *testing.T) { conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary) go func() { - fn, err := handleProxy(ctx, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext()) + fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext()) t.Cleanup(func() { // Use of the t.Cleanup is so we can check the state of the returned // error since it isn't valid to call `t.FailNow()` from a goroutine.