Provide a worker decryption function to proxy handlers (#2784)

pull/2785/head
Todd 3 years ago committed by GitHub
parent 0b87acb896
commit 38b74ce897
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -45,7 +45,10 @@ var (
_ pbs.ServerCoordinationServiceServer = &workerServiceServer{}
workerFilterSelectionFn = workerFilterSelector
connectionRouteFn = singleHopConnectionRoute
// connectionRouteFn returns a route to the egress worker. If the requester
// is the egress worker a route of length 1 is returned. A route of
// length 0 is never returned unless there is an error.
connectionRouteFn = singleHopConnectionRoute
// getProtocolContext populates the protocol specific context fields
// depending on the protocol used to for the boundary connection. Defaults
@ -304,7 +307,7 @@ func workerFilterSelector(sessionInfo *session.Session) string {
}
// noProtocolContext doesn't provide any protocol context since tcp doesn't need any
func noProtocolContext(context.Context, *session.Repository, *pbs.AuthorizeConnectionRequest) (*anypb.Any, error) {
func noProtocolContext(context.Context, *session.Repository, *server.Repository, common.WorkerAuthRepoStorageFactory, *pbs.AuthorizeConnectionRequest, []string) (*anypb.Any, error) {
return nil, nil
}
@ -527,7 +530,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs
ConnectionsLeft: authzSummary.ConnectionLimit,
Route: route,
}
if pc, err := getProtocolContext(ctx, sessionRepo, req); err != nil {
if pc, err := getProtocolContext(ctx, sessionRepo, serversRepo, ws.workerAuthRepoFn, req, route); err != nil {
return nil, err
} else {
ret.ProtocolContext = pc

@ -2,7 +2,7 @@ package worker
import (
"context"
"errors"
stderrors "errors"
"fmt"
"io"
"net"
@ -16,10 +16,15 @@ import (
"github.com/hashicorp/boundary/internal/daemon/worker/internal/metric"
proxyHandlers "github.com/hashicorp/boundary/internal/daemon/worker/proxy"
"github.com/hashicorp/boundary/internal/daemon/worker/session"
"github.com/hashicorp/boundary/internal/errors"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/boundary/internal/proxy"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/types"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"nhooyr.io/websocket"
@ -62,7 +67,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
return func(wr http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.TLS == nil {
event.WriteError(ctx, op, errors.New("no request tls information found"))
event.WriteError(ctx, op, stderrors.New("no request tls information found"))
wr.WriteHeader(http.StatusInternalServerError)
return
}
@ -78,7 +83,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
}
}
if sessionId == "" {
event.WriteError(ctx, op, errors.New("no session id could be found in peer certificates"))
event.WriteError(ctx, op, stderrors.New("no session id could be found in peer certificates"))
wr.WriteHeader(http.StatusInternalServerError)
return
}
@ -108,7 +113,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
sess := sessionManager.Get(sessionId)
if sess == nil {
event.WriteError(ctx, op, errors.New("session not found locally"), event.WithInfo("session_id", sessionId))
event.WriteError(ctx, op, stderrors.New("session not found locally"), event.WithInfo("session_id", sessionId))
wr.WriteHeader(http.StatusInternalServerError)
return
}
@ -137,7 +142,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
return
}
if len(handshake.GetTofuToken()) < 20 {
event.WriteError(ctx, op, errors.New("invalid tofu token"))
event.WriteError(ctx, op, stderrors.New("invalid tofu token"))
if err = conn.Close(websocket.StatusUnsupportedData, "invalid tofu token"); err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
@ -147,12 +152,12 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
if handshake.Command == proxy.HANDSHAKECOMMAND_HANDSHAKECOMMAND_SESSION_CANCEL {
if err := sess.RequestCancel(ctx); err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("unable to cancel session"))
if err = conn.Close(websocket.StatusInternalError, "unable to cancel session"); err != nil && !errors.Is(err, io.EOF) {
if err = conn.Close(websocket.StatusInternalError, "unable to cancel session"); err != nil && !stderrors.Is(err, io.EOF) {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
return
}
if err = conn.Close(websocket.StatusNormalClosure, "session canceled"); err != nil && !errors.Is(err, io.EOF) {
if err = conn.Close(websocket.StatusNormalClosure, "session canceled"); err != nil && !stderrors.Is(err, io.EOF) {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
return
@ -160,7 +165,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
if sess.GetTofuToken() != "" {
if sess.GetTofuToken() != handshake.GetTofuToken() {
event.WriteError(ctx, op, errors.New("WARNING: mismatched tofu token"), event.WithInfo("session_id", sessionId))
event.WriteError(ctx, op, stderrors.New("WARNING: mismatched tofu token"), event.WithInfo("session_id", sessionId))
if err = conn.Close(websocket.StatusPolicyViolation, "tofu token not allowed"); err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
@ -168,7 +173,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
}
} else {
if sess.GetStatus() != pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING {
event.WriteError(ctx, op, errors.New("no tofu token but not in correct session state"), event.WithInfo("session_id", sessionId))
event.WriteError(ctx, op, stderrors.New("no tofu token but not in correct session state"), event.WithInfo("session_id", sessionId))
if err = conn.Close(websocket.StatusInternalError, "refusing to activate session"); err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
@ -188,7 +193,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
}
if w.LastStatusSuccess() == nil || w.LastStatusSuccess().WorkerId == "" {
event.WriteError(ctx, op, errors.New("worker id is empty"))
event.WriteError(ctx, op, stderrors.New("worker id is empty"))
if err = conn.Close(websocket.StatusInternalError, "worker id is empty"); err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error closing client connection"))
}
@ -279,7 +284,12 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
event.WriteError(ctx, op, err)
return
}
runProxy, err := handleProxyFn(ctx, cc, pDialer, acResp.GetConnectionId(), protocolCtx)
decryptFn, err := w.credDecryptFn(ctx)
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function")
event.WriteError(ctx, op, err)
}
runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx)
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying")
event.WriteError(ctx, op, err)
@ -310,6 +320,32 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
}, nil
}
// credDecryptFn returns a DecryptFn if the worker is a pki worker with
// WorkerAuthStorage defined. An error is returned if there is an error
// loading the node credentials.
func (w *Worker) credDecryptFn(ctx context.Context) (proxyHandlers.DecryptFn, error) {
const op = "worker.(*Worker).credDecryptFn"
if w.WorkerAuthStorage == nil {
return nil, nil
}
was := w.WorkerAuthStorage
var opts []nodeenrollment.Option
if !util.IsNil(w.conf.WorkerAuthStorageKms) {
opts = append(opts, nodeenrollment.WithWrapper(w.conf.WorkerAuthStorageKms))
}
nodeCreds, err := types.LoadNodeCredentials(ctx, was, nodeenrollment.CurrentId, opts...)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return func(ctx context.Context, from []byte, to proto.Message) error {
if err := nodeenrollment.DecryptMessage(ctx, from, nodeCreds, to); err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}, nil
}
func (w *Worker) wrapGenericHandler(h http.Handler, _ HandlerProperties) http.Handler {
return http.HandlerFunc(func(wr http.ResponseWriter, r *http.Request) {
// Set the Cache-Control header for all responses returned

@ -28,6 +28,9 @@ var (
GetHandler = tcpOnly
)
// DecryptFn decrypts the provided bytes into a proto.Message
type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error
// ProxyConnFn is called after the call to ConnectConnection on the cluster.
// ProxyConnFn blocks until the specific request that is being proxied is finished
type ProxyConnFn func(ctx context.Context)
@ -37,7 +40,7 @@ type ProxyConnFn func(ctx context.Context)
// be nil. If there is no error ProxyConnFn must be set. When Handler has
// returned, it is expected that the initial connection to the endpoint has been
// established.
type Handler func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error)
type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error)
func RegisterHandler(protocol string, handler Handler) error {
_, loaded := handlers.LoadOrStore(protocol, handler)

@ -14,7 +14,7 @@ import (
func TestRegisterHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers
@ -37,7 +37,7 @@ func TestRegisterHandler(t *testing.T) {
func TestAlwaysTcpGetHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers
@ -45,9 +45,7 @@ func TestAlwaysTcpGetHandler(t *testing.T) {
handlers = oldHandler
})
handlers = sync.Map{}
_, err := tcpOnly("wid", nil)
require.Error(err)
assert.ErrorIs(err, ErrUnknownProtocol)
require.NoError(RegisterHandler("tcp", fn))

@ -24,7 +24,7 @@ func init() {
// handleProxy returns a ProxyConnFn which starts the copy between the
// connections and blocks until an error (EOF on happy path) is received on
// either connection.
func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) {
func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) {
const op = "tcp.HandleProxy"
switch {
case conn == nil:

@ -85,7 +85,7 @@ func TestHandleProxy_Errors(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
fn, err := handleProxy(context.Background(), tc.conn, tc.dialer, tc.connId, tc.protocolCtx)
fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx)
if tc.wantError {
assert.Error(t, err)
assert.Nil(t, fn)
@ -163,7 +163,7 @@ func TestHandleTcpProxyV1(t *testing.T) {
conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary)
go func() {
fn, err := handleProxy(ctx, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext())
fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext())
t.Cleanup(func() {
// Use of the t.Cleanup is so we can check the state of the returned
// error since it isn't valid to call `t.FailNow()` from a goroutine.

Loading…
Cancel
Save