From 4114b415277c92bdded76eae9b137c3ec9f36d3c Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Fri, 3 Feb 2023 11:40:41 -0500 Subject: [PATCH] Filter managed workers from egress workers when the host address is unsafe (#2899) --- .../daemon/cluster/handlers/worker_service.go | 69 ++++++++++++------- .../cluster/handlers/worker_service_test.go | 23 +++++-- internal/daemon/common/const.go | 1 + .../handlers/targets/target_service.go | 48 +++++-------- .../handlers/targets/target_service_test.go | 2 +- 5 files changed, 79 insertions(+), 64 deletions(-) diff --git a/internal/daemon/cluster/handlers/worker_service.go b/internal/daemon/cluster/handlers/worker_service.go index d8dc8acc51..a3e55db90c 100644 --- a/internal/daemon/cluster/handlers/worker_service.go +++ b/internal/daemon/cluster/handlers/worker_service.go @@ -2,12 +2,12 @@ package handlers import ( "context" - "errors" "fmt" "sync" "sync/atomic" "time" + dcommon "github.com/hashicorp/boundary/internal/daemon/common" "github.com/hashicorp/boundary/internal/daemon/controller/common" "github.com/hashicorp/boundary/internal/daemon/controller/handlers" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -24,8 +24,6 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) -const ManagedWorkerTagKey = "boundary.cloud.hashicorp.com:managed" - type workerServiceServer struct { pbs.UnsafeServerCoordinationServiceServer pbs.UnsafeSessionServiceServer @@ -44,7 +42,7 @@ var ( _ pbs.SessionServiceServer = &workerServiceServer{} _ pbs.ServerCoordinationServiceServer = &workerServiceServer{} - workerFilterSelectionFn = workerFilterSelector + workerFilterSelectionFn = egressFilterSelector // connectionRouteFn returns a route to the egress worker. If the requester // is the egress worker a route of length 1 is returned. A route of // length 0 is never returned unless there is an error. @@ -59,8 +57,8 @@ var ( ) // singleHopConnectionRoute returns a route consisting of the singlehop worker (the root worker id) -func singleHopConnectionRoute(_ context.Context, rootInfo server.RootInfo, _ *session.AuthzSummary, _ *server.Repository, _ common.Downstreamers) ([]string, error) { - return []string{rootInfo.RootId}, nil +func singleHopConnectionRoute(_ context.Context, w *server.Worker, _ *session.Session, _ *session.AuthzSummary, _ *server.Repository, _ common.Downstreamers) ([]string, error) { + return []string{w.GetPublicId()}, nil } func NewWorkerServiceServer( @@ -289,7 +287,7 @@ func (ws *workerServiceServer) ListHcpbWorkers(ctx context.Context, req *pbs.Lis resp.Workers = make([]*pbs.WorkerInfo, 0, len(workers)) for _, worker := range workers { - vals := worker.CanonicalTags()[ManagedWorkerTagKey] + vals := worker.CanonicalTags()[dcommon.ManagedWorkerTag] if len(vals) == 1 && vals[0] == "true" { resp.Workers = append(resp.Workers, &pbs.WorkerInfo{ Id: worker.GetPublicId(), @@ -303,7 +301,7 @@ func (ws *workerServiceServer) ListHcpbWorkers(ctx context.Context, req *pbs.Lis // Single-hop filter lookup. We have either an egress filter or worker filter to use, if any // Used to verify that the worker serving this session to a client matches this filter -func workerFilterSelector(sessionInfo *session.Session) string { +func egressFilterSelector(sessionInfo *session.Session) string { if sessionInfo.EgressWorkerFilter != "" { return sessionInfo.EgressWorkerFilter } else if sessionInfo.WorkerFilter != "" { @@ -317,26 +315,17 @@ func noProtocolContext(context.Context, *session.Repository, *server.Repository, return nil, nil } -func lookupSessionWorkerFilter(ctx context.Context, sessionInfo *session.Session, ws *workerServiceServer, +func lookupSessionWorkerFilter(ctx context.Context, sessionInfo *session.Session, authzSummary *session.AuthzSummary, ws *workerServiceServer, req *pbs.LookupSessionRequest, ) error { const op = "workers.lookupSessionEgressWorkerFilter" - filter := workerFilterSelectionFn(sessionInfo) - if filter == "" { - return nil - } - - if req.WorkerId == "" { - event.WriteError(ctx, op, errors.New("worker filter enabled for session but got no id information from worker")) - return status.Errorf(codes.Internal, "Did not receive worker id when looking up session but filtering is enabled") - } serversRepo, err := ws.serversRepoFn() if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error getting server repo")) return status.Errorf(codes.Internal, "Error acquiring server repo when looking up session: %v", err) } - w, err := serversRepo.LookupWorker(ctx, req.WorkerId) + w, err := serversRepo.LookupWorker(ctx, req.GetWorkerId()) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("error looking up worker", "worker_id", req.WorkerId)) return status.Errorf(codes.Internal, "Error looking up worker: %v", err) @@ -345,6 +334,20 @@ func lookupSessionWorkerFilter(ctx context.Context, sessionInfo *session.Session event.WriteError(ctx, op, err, event.WithInfoMsg("error looking up worker", "worker_id", req.WorkerId)) return status.Errorf(codes.Internal, "Worker not found") } + + filter := workerFilterSelectionFn(sessionInfo) + if filter == "" { + // Verify that this ingress worker can build a route to the endpoint safely + // While the AuthorizeSession may have done a similar check, this makes sure + // we can select a worker for egress that wouldn't potentially grant access + // to a private ip address in the network of the boundary deployment in the + // case of hcp. + if _, err := connectionRouteFn(ctx, w, sessionInfo, authzSummary, serversRepo, ws.downstreams); err != nil { + return status.Errorf(codes.Internal, "error calculating route to endpoint: %v", err) + } + return nil + } + // Build the map for filtering. tagMap := w.CanonicalTags() @@ -366,12 +369,25 @@ func lookupSessionWorkerFilter(ctx context.Context, sessionInfo *session.Session return handlers.ApiErrorWithCodeAndMessage(codes.FailedPrecondition, "Worker filter expression precludes this worker from serving this session") } + // Verify that this ingress worker can build a route to the endpoint safely + // While the AuthorizeSession may have done a similar check, this makes sure + // we can select a worker for egress that wouldn't potentially grant access + // to a private ip address in the network of the boundary deployment in the + // case of hcp. + if _, err = connectionRouteFn(ctx, w, sessionInfo, authzSummary, serversRepo, ws.downstreams); err != nil { + return status.Errorf(codes.Internal, "error calculating route to endpoint: %v", err) + } + return nil } func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) { const op = "workers.(workerServiceServer).LookupSession" + if req.WorkerId == "" { + return nil, status.Errorf(codes.InvalidArgument, "Did not receive worker id when looking up session") + } + sessRepo, err := ws.sessionRepoFn() if err != nil { return nil, status.Errorf(codes.Internal, "Error getting session repo: %v", err) @@ -388,7 +404,7 @@ func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.Looku return nil, status.Error(codes.Internal, "Empty session states during lookup.") } - err = lookupSessionWorkerFilter(ctx, sessionInfo, ws, req) + err = lookupSessionWorkerFilter(ctx, sessionInfo, authzSummary, ws, req) if err != nil { return nil, err } @@ -509,7 +525,7 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs return nil, status.Errorf(codes.NotFound, "worker not found with name %q", req.GetWorkerId()) } - connectionInfo, connStates, authzSummary, err := session.AuthorizeConnection(ctx, sessionRepo, connectionRepo, req.GetSessionId(), w.GetPublicId()) + connectionInfo, connStates, err := connectionRepo.AuthorizeConnection(ctx, req.GetSessionId(), w.GetPublicId()) if err != nil { return nil, err } @@ -520,12 +536,15 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs return nil, status.Error(codes.Internal, "Invalid connection state in authorize response.") } - rootInfo := server.RootInfo{ - RootId: req.GetWorkerId(), - RootVer: w.ReleaseVersion, + sessInfo, authzSummary, err := sessionRepo.LookupSession(ctx, req.GetSessionId()) + if err != nil { + return nil, err + } + if sessInfo == nil { + return nil, status.Errorf(codes.Internal, "Invalid session info in lookup session response") } - route, err := connectionRouteFn(ctx, rootInfo, authzSummary, serversRepo, ws.downstreams) + route, err := connectionRouteFn(ctx, w, sessInfo, authzSummary, serversRepo, ws.downstreams) if err != nil { return nil, status.Errorf(codes.Internal, "error getting route to egress worker: %v", err) } diff --git a/internal/daemon/cluster/handlers/worker_service_test.go b/internal/daemon/cluster/handlers/worker_service_test.go index 07447ee610..ec1f6c961e 100644 --- a/internal/daemon/cluster/handlers/worker_service_test.go +++ b/internal/daemon/cluster/handlers/worker_service_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/hashicorp/boundary/internal/authtoken" credstatic "github.com/hashicorp/boundary/internal/credential/static" + dcommon "github.com/hashicorp/boundary/internal/daemon/common" "github.com/hashicorp/boundary/internal/db" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/hashicorp/boundary/internal/host/static" @@ -143,6 +144,12 @@ func TestLookupSession(t *testing.T) { s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) require.NotNil(t, s) + oldFn := connectionRouteFn + connectionRouteFn = singleHopConnectionRoute + t.Cleanup(func() { + connectionRouteFn = oldFn + }) + cases := []struct { name string wantErr bool @@ -154,6 +161,7 @@ func TestLookupSession(t *testing.T) { name: "Invalid session id", req: &pbs.LookupSessionRequest{ SessionId: "s_fakesession", + WorkerId: worker1.GetPublicId(), }, wantErr: true, wantErrMsg: "rpc error: code = PermissionDenied desc = Unknown session ID.", @@ -164,7 +172,7 @@ func TestLookupSession(t *testing.T) { SessionId: sessWithWorkerFilter.PublicId, }, wantErr: true, - wantErrMsg: "rpc error: code = Internal desc = Did not receive worker id when looking up session but filtering is enabled", + wantErrMsg: "rpc error: code = InvalidArgument desc = Did not receive worker id when looking up session", }, { name: "nonexistant worker id", @@ -179,6 +187,7 @@ func TestLookupSession(t *testing.T) { name: "Valid", req: &pbs.LookupSessionRequest{ SessionId: sess.PublicId, + WorkerId: worker1.GetPublicId(), }, want: &pbs.LookupSessionResponse{ Authorization: &targets.SessionAuthorizationData{ @@ -224,6 +233,7 @@ func TestLookupSession(t *testing.T) { name: "Valid-with-worker-creds", req: &pbs.LookupSessionRequest{ SessionId: sessWithCreds.PublicId, + WorkerId: worker1.GetPublicId(), }, want: &pbs.LookupSessionResponse{ Authorization: &targets.SessionAuthorizationData{ @@ -256,6 +266,7 @@ func TestLookupSession(t *testing.T) { assert.Equal(tc.wantErrMsg, err.Error()) return } + require.NoError(err) assert.Empty( cmp.Diff( tc.want, @@ -547,22 +558,22 @@ func TestHcpbWorkers(t *testing.T) { liveDur.Store(int64(1 * time.Second)) // Stale/unalive kms worker aren't expected... - server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: ManagedWorkerTagKey, Value: "true"}), + server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"}), server.WithAddress("old.kms.1")) // Sleep + 500ms longer than the liveness duration. time.Sleep(time.Duration(liveDur.Load()) + time.Second) - server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: ManagedWorkerTagKey, Value: "true"}), + server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"}), server.WithAddress("kms.1")) - server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: ManagedWorkerTagKey, Value: "true"}), + server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"}), server.WithAddress("kms.2")) // Shutdown workers will be removed from routes and sessions, but still returned // to downstream workers - server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: ManagedWorkerTagKey, Value: "true"}), + server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"}), server.WithAddress("shutdown.kms.3"), server.WithOperationalState(server.ShutdownOperationalState.String())) // PKI workers aren't expected - server.TestPkiWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: ManagedWorkerTagKey, Value: "true"})) + server.TestPkiWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"})) s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kmsCache, &liveDur) require.NotNil(t, s) diff --git a/internal/daemon/common/const.go b/internal/daemon/common/const.go index 013d89603c..df81fd989d 100644 --- a/internal/daemon/common/const.go +++ b/internal/daemon/common/const.go @@ -3,4 +3,5 @@ package common const ( ReverseGrpcConnectionAlpnValue = "the-downstream-dialer-plays-an-uno-reverse-card" DataPlaneProxyAlpnValue = "i-herd-you-like-proxies-so-i-put-a-proxy-in-your-proxy" + ManagedWorkerTag = "boundary.cloud.hashicorp.com:managed" ) diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index e6062712f9..e44010c083 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -622,9 +622,9 @@ func (s Service) RemoveTargetCredentialSources(ctx context.Context, req *pbs.Rem return &pbs.RemoveTargetCredentialSourcesResponse{Item: item}, nil } -// If set, use the worker_filter or egress_worker_filter to filtere the selected workers +// If set, use the worker_filter or egress_worker_filter to filter the selected workers // and ensure we have workers available to service this request. -func AuthorizeSessionWithWorkerFilter(_ context.Context, t target.Target, selectedWorkers wl.WorkerList, _ common.Downstreamers) (wl.WorkerList, error) { +func AuthorizeSessionWithWorkerFilter(_ context.Context, t target.Target, selectedWorkers wl.WorkerList, _ string, _ common.Downstreamers) (wl.WorkerList, error) { if len(selectedWorkers) > 0 { var eval *bexpr.Evaluator var err error @@ -728,26 +728,6 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession return nil, err } - // First ensure we can actually service a request, that is, we have workers - // available (after any filtering). WorkerInfo only contains the address; - // worker IDs below is used to contain their IDs in the same order. This is - // used to fetch tags for filtering. But we avoid allocation unless we - // actually need it. - selectedWorkers, err := serversRepo.ListWorkers(ctx, []string{scope.Global.String()}, server.WithLiveness(time.Duration(s.workerStatusGracePeriod.Load()))) - if err != nil { - return nil, err - } - - selectedWorkers, err = AuthorizeSessionWorkerFilterFn(ctx, t, selectedWorkers, s.downstreams) - if err != nil { - return nil, err - } - - // Randomize the workers - rand.Shuffle(len(selectedWorkers), func(i, j int) { - selectedWorkers[i], selectedWorkers[j] = selectedWorkers[j], selectedWorkers[i] - }) - p := strconv.FormatUint(uint64(t.GetDefaultPort()), 10) var h, hostId, hostSetId string @@ -852,18 +832,22 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession Host: net.JoinHostPort(h, p), } - for _, extraFilter := range ExtraWorkerFilters { - selectedWorkers, err = extraFilter(ctx, selectedWorkers, h, p) - if err != nil { - return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error executing extra worker filter")) - } - if len(selectedWorkers) == 0 { - return nil, handlers.ApiErrorWithCodeAndMessage( - codes.FailedPrecondition, - "No workers are available to handle this session, or all have been filtered.") - } + // Get workers and filter down to ones that can service this request + selectedWorkers, err := serversRepo.ListWorkers(ctx, []string{scope.Global.String()}, server.WithLiveness(time.Duration(s.workerStatusGracePeriod.Load()))) + if err != nil { + return nil, err } + selectedWorkers, err = AuthorizeSessionWorkerFilterFn(ctx, t, selectedWorkers, h, s.downstreams) + if err != nil { + return nil, err + } + + // Randomize the workers + rand.Shuffle(len(selectedWorkers), func(i, j int) { + selectedWorkers[i], selectedWorkers[j] = selectedWorkers[j], selectedWorkers[i] + }) + var vaultReqs []credential.Request var staticIds []string var dynCreds []*session.DynamicCredential diff --git a/internal/daemon/controller/handlers/targets/target_service_test.go b/internal/daemon/controller/handlers/targets/target_service_test.go index 6a2d3cb500..b8e8a8d876 100644 --- a/internal/daemon/controller/handlers/targets/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/target_service_test.go @@ -114,7 +114,7 @@ func TestWorkerList_EgressFilter(t *testing.T) { if len(tc.filter) > 0 { target.EgressWorkerFilter = tc.filter } - out, err := AuthorizeSessionWithWorkerFilter(ctx, target, tc.in, nil) + out, err := AuthorizeSessionWithWorkerFilter(ctx, target, tc.in, "", nil) if tc.errContains != "" { assert.Contains(err.Error(), tc.errContains) assert.Nil(out)