From 19180af0ebdf5f22040b045d010aa8cedb858a80 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 26 Jan 2023 10:35:29 -0800 Subject: [PATCH] Fix target port handling (#2846) * Fix target port handling This fixes two issues that compounded on each other (see the Changelog update for more information): * The verification logic for hosts was not correct for update operations (in multiple ways) which meant that a host could be updated after creation to have a port. Targets had a previously fixed bug where they did not require a default port, which meant that ports could be used from hosts * A recent (unreleased) change had prioritized any port coming from the host over the default port, which would mean there was no way if a host had a port specified to use it in multiple targets. This fixes the update verification logic, and strikes a middle ground between breaking things and not by allowing existing addresses with ports to be used with targets but ignoring that port, instead requiring targets to have default port set at authorize time (currently there is backwards compat that does not require this due to the original optional port bug). * Update CHANGELOG.md Co-authored-by: Johan Brandhorst-Satzkorn --- CHANGELOG.md | 15 ++ globals/errors.go | 3 + globals/fields.go | 3 +- internal/cmd/commands/server/server.go | 4 +- internal/cmd/config/config.go | 2 +- .../credentialstore_service.go | 7 +- .../credentialstore_service_test.go | 2 +- .../controller/handlers/hosts/host_service.go | 52 +++--- .../handlers/hosts/host_service_test.go | 151 +++++++++++------ .../handlers/targets/target_service.go | 156 ++++++++++-------- .../targets/tcp/target_service_test.go | 8 +- .../handlers/workers/worker_service.go | 2 +- .../daemon/worker/controller_connection.go | 2 +- 13 files changed, 254 insertions(+), 153 deletions(-) create mode 100644 globals/errors.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 72753d6b0d..925e0d18d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ Canonical reference for changes, improvements, and bugfixes for Boundary. ## 0.12.0 (2023/01/24) +### Deprecations/Changes + +* In Boundary 0.9.0, targets were updated to require a default port value. This + had been the original intention; it was a mistake that it was optional. + Unfortunately, due to a separate defect in the update verification logic for + static hosts, it was possible for a host to be updated (but not created) with + a port. This meant that targets could use ports attached to host addresses, + which was not the intention and leads to confusing behavior across different + installations. In this version, updating static hosts will no longer allow + ports to be part of the address; when authorizing a session, any port on such + a host will be ignored in favor of the default port on the target. In Boundary + 0.14.0, this will become an error instead. As a consequence, it means that the + fallback logic for targets that did not have a default port defined is no + longer in service; all targets must now have a default port defined. + ### New and Improved * Direct Address Targets: You can now set an address directly on a target, diff --git a/globals/errors.go b/globals/errors.go new file mode 100644 index 0000000000..a57c02a0b6 --- /dev/null +++ b/globals/errors.go @@ -0,0 +1,3 @@ +package globals + +const MissingPortErrStr = "missing port in address" diff --git a/globals/fields.go b/globals/fields.go index 63ad052300..acff59df20 100644 --- a/globals/fields.go +++ b/globals/fields.go @@ -93,5 +93,6 @@ const ( KeyVersionIdField = "key_version_id" CompletedCountField = "completed_count" TotalCountField = "total_count" - DirectlyConnectedDownstreamWorkers = "directly_connected_downstream_workers" + DirectlyConnectedDownstreamWorkersField = "directly_connected_downstream_workers" + AttributesAddressField = "attributes.address" ) diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index a088c89933..98c0a8c013 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -311,7 +311,7 @@ func (c *Command) Run(args []string) int { for _, upstream := range c.Config.Worker.InitialUpstreams { host, _, err := net.SplitHostPort(upstream) if err != nil { - if strings.Contains(err.Error(), "missing port in address") { + if strings.Contains(err.Error(), globals.MissingPortErrStr) { host = upstream } else { c.UI.Error(fmt.Errorf("Invalid worker upstream address %q: %w", upstream, err).Error()) @@ -355,7 +355,7 @@ func (c *Command) Run(args []string) int { } host, _, err := net.SplitHostPort(ln.Address) if err != nil { - if strings.Contains(err.Error(), "missing port in address") { + if strings.Contains(err.Error(), globals.MissingPortErrStr) { host = ln.Address } else { c.UI.Error(fmt.Errorf("Invalid cluster listener address %q: %w", ln.Address, err).Error()) diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 40f434d74c..224a7226d5 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -1103,7 +1103,7 @@ func (c *Config) SetupWorkerInitialUpstreams() error { } // Best effort see if it's a domain name and if not assume it must match host, _, err := net.SplitHostPort(c.Worker.InitialUpstreams[0]) - if err != nil && strings.Contains(err.Error(), "missing port in address") { + if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) { err = nil host = c.Worker.InitialUpstreams[0] } diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go index 6b1c1f7c55..c0f91eccfa 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go @@ -32,7 +32,6 @@ import ( ) const ( - addressField = "attributes.address" vaultTokenField = "attributes.token" vaultTokenHmacField = "attributes.token_hmac" vaultWorkerFilterField = "attributes.worker_filter" @@ -793,7 +792,7 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateCredentialStoreRe } if attrs.GetAddress().GetValue() == "" { - badFields[addressField] = "Field required for creating a vault credential store." + badFields[globals.AttributesAddressField] = "Field required for creating a vault credential store." } if attrs.GetToken().GetValue() == "" { badFields[vaultTokenField] = "Field required for creating a vault credential store." @@ -842,9 +841,9 @@ func validateUpdateRequest(ctx context.Context, req *pbs.UpdateCredentialStoreRe } attrs := req.GetItem().GetVaultCredentialStoreAttributes() if attrs != nil { - if handlers.MaskContains(req.GetUpdateMask().GetPaths(), addressField) && + if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) && attrs.GetAddress().GetValue() == "" { - badFields[addressField] = "This is a required field and cannot be unset." + badFields[globals.AttributesAddressField] = "This is a required field and cannot be unset." } if handlers.MaskContains(req.GetUpdateMask().GetPaths(), vaultTokenField) && attrs.GetToken().GetValue() == "" { diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go index 4331455119..8086b37c9d 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go @@ -1035,7 +1035,7 @@ func TestUpdateVault(t *testing.T) { { name: "update connection info", req: &pbs.UpdateCredentialStoreRequest{ - UpdateMask: fieldmask("attributes.address", "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"), + UpdateMask: fieldmask(globals.AttributesAddressField, "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"), Item: &pb.CredentialStore{ Attrs: &pb.CredentialStore_VaultCredentialStoreAttributes{ VaultCredentialStoreAttributes: &pb.VaultCredentialStoreAttributes{ diff --git a/internal/daemon/controller/handlers/hosts/host_service.go b/internal/daemon/controller/handlers/hosts/host_service.go index 1a5bf70f54..89a339cd86 100644 --- a/internal/daemon/controller/handlers/hosts/host_service.go +++ b/internal/daemon/controller/handlers/hosts/host_service.go @@ -637,21 +637,22 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error { attrs := req.GetItem().GetStaticHostAttributes() switch { case attrs == nil: - badFields["attributes"] = "This is a required field." + badFields[globals.AttributesField] = "This is a required field." default: if attrs.GetAddress() == nil || len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength || len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) - } - _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) - switch { - case err == nil: - badFields["attributes.address"] = "Address for static hosts does not support a port." - case strings.Contains(err.Error(), "missing port in address"): - // Bare hostname, which we want - default: - badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err) + badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } else { + _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) + switch { + case err == nil: + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + case strings.Contains(err.Error(), globals.MissingPortErrStr): + // Bare hostname, which we want + default: + badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) + } } } case plugin.Subtype: @@ -668,17 +669,26 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error { case static.Subtype: if req.GetItem().GetType() != "" && req.GetItem().GetType() != static.Subtype.String() { badFields[globals.TypeField] = "Cannot modify the resource type." + } + if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) { attrs := req.GetItem().GetStaticHostAttributes() - - if handlers.MaskContains(req.GetUpdateMask().GetPaths(), "attributes.address") { - switch { - case attrs == nil: - badFields["attributes"] = "Attributes field not supplied request" - default: - if attrs.GetAddress() == nil || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + switch { + case attrs == nil: + badFields[globals.AttributesField] = "Attributes field not supplied in request" + default: + if attrs.GetAddress() == nil || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { + badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } else { + _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) + switch { + case err == nil: + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + case strings.Contains(err.Error(), globals.MissingPortErrStr): + // Bare hostname, which we want + default: + badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) } } } diff --git a/internal/daemon/controller/handlers/hosts/host_service_test.go b/internal/daemon/controller/handlers/hosts/host_service_test.go index 1a5acd3c08..6ea534a523 100644 --- a/internal/daemon/controller/handlers/hosts/host_service_test.go +++ b/internal/daemon/controller/handlers/hosts/host_service_test.go @@ -771,6 +771,21 @@ func TestCreate(t *testing.T) { res: nil, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Can't specify port", + req: &pbs.CreateHostRequest{Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Name: &wrappers.StringValue{Value: "port name"}, + Description: &wrappers.StringValue{Value: "port desc"}, + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("123.456.789:12345"), + }, + }, + }}, + res: nil, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, { name: "Can't specify Created Time", req: &pbs.CreateHostRequest{Item: &pb.Host{ @@ -798,11 +813,16 @@ func TestCreate(t *testing.T) { got, gErr := s.CreateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), tc.req) if tc.err != nil { require.Error(gErr) + assert.Nil(got) assert.True(errors.Is(gErr, tc.err), "CreateHost(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) if tc.wantErrContains != "" { assert.Contains(gErr.Error(), tc.wantErrContains) } + return } + require.NoError(gErr) + require.NotNil(got) + if got != nil { assert.Contains(got.GetUri(), tc.res.GetUri()) assert.True(strings.HasPrefix(got.GetItem().GetId(), static.HostPrefix)) @@ -868,22 +888,21 @@ func TestUpdate_Static(t *testing.T) { } hCreated := h.GetCreateTime().GetTimestamp().AsTime() - toMerge := &pbs.UpdateHostRequest{ - Id: h.GetPublicId(), - } tested, err := hosts.NewService(repoFn, pluginRepoFn) require.NoError(t, err, "Failed to create a new host set service.") cases := []struct { - name string - req *pbs.UpdateHostRequest - res *pbs.UpdateHostResponse - err error + name string + req *pbs.UpdateHostRequest + res *pbs.UpdateHostResponse + err error + wantErrContains string }{ { name: "Update an Existing Host", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name", "description", "type"}, }, @@ -915,6 +934,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Multiple Paths in single string", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name,description,type"}, }, @@ -946,16 +966,19 @@ func TestUpdate_Static(t *testing.T) { { name: "No Update Mask", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "UpdateMask not provided", }, { - name: "No Update Mask", + name: "Changing Type", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name,type"}, }, @@ -964,33 +987,39 @@ func TestUpdate_Static(t *testing.T) { Type: "ec2", }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Cannot modify the resource type", }, { name: "Empty Path", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{Paths: []string{}}, Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "No valid fields provided", }, { - name: "Only non-existant paths in Mask", + name: "Only non-existent paths in Mask", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{Paths: []string{"nonexistant_field"}}, Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "No valid fields provided", }, { name: "Unset Name", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1019,6 +1048,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Unset Description", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"description"}, }, @@ -1047,6 +1077,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Update Only Name", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1077,6 +1108,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Update Only Description", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"description"}, }, @@ -1117,12 +1149,13 @@ func TestUpdate_Static(t *testing.T) { Description: &wrappers.StringValue{Value: "desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.NotFound), + err: handlers.ApiErrorWithCode(codes.NotFound), + wantErrContains: "Resource not found", }, { name: "Cant change Id", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"id"}, }, @@ -1133,15 +1166,15 @@ func TestUpdate_Static(t *testing.T) { Description: &wrappers.StringValue{Value: "new desc"}, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Cant unset address", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ - Paths: []string{"attributes.address"}, + Paths: []string{globals.AttributesAddressField}, }, Item: &pb.Host{ Attrs: &pb.Host_StaticHostAttributes{ @@ -1151,15 +1184,15 @@ func TestUpdate_Static(t *testing.T) { }, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Address length must be", }, { name: "Cant set address to empty string", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ - Paths: []string{"attributes.address"}, + Paths: []string{globals.AttributesAddressField}, }, Item: &pb.Host{ Attrs: &pb.Host_StaticHostAttributes{ @@ -1169,12 +1202,33 @@ func TestUpdate_Static(t *testing.T) { }, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Address length must be", + }, + { + name: "Cant specify port in address", + req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), + UpdateMask: &field_mask.FieldMask{ + Paths: []string{globals.AttributesAddressField}, + }, + Item: &pb.Host{ + Name: &wrappers.StringValue{Value: "port name"}, + Description: &wrappers.StringValue{Value: "port desc"}, + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("123.456.789:12345"), + }, + }, + }, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "does not support a port", }, { name: "Cant specify Created Time", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"created_time"}, }, @@ -1182,12 +1236,13 @@ func TestUpdate_Static(t *testing.T) { CreatedTime: timestamppb.Now(), }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Cant specify Updated Time", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"updated_time"}, }, @@ -1195,12 +1250,13 @@ func TestUpdate_Static(t *testing.T) { UpdatedTime: timestamppb.Now(), }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Valid mask, cant specify type", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1208,8 +1264,8 @@ func TestUpdate_Static(t *testing.T) { Type: "Unknown", }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Cannot modify the resource type", }, } for _, tc := range cases { @@ -1217,8 +1273,7 @@ func TestUpdate_Static(t *testing.T) { assert, require := assert.New(t), require.New(t) tc.req.Item.Version = version - req := proto.Clone(toMerge).(*pbs.UpdateHostRequest) - proto.Merge(req, tc.req) + req := tc.req // Test some bad versions req.Item.Version = version + 2 @@ -1232,26 +1287,26 @@ func TestUpdate_Static(t *testing.T) { got, gErr := tested.UpdateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), req) if tc.err != nil { require.Error(gErr) + assert.Nil(got) assert.True(errors.Is(gErr, tc.err), "UpdateHost(%+v) got error %v, wanted %v", req, gErr, tc.err) + if tc.wantErrContains != "" { + assert.Contains(gErr.Error(), tc.wantErrContains) + } + return } - if tc.err == nil { - defer resetHost() - } + defer resetHost() + require.NoError(gErr) + require.NotNil(got) - if got != nil { - assert.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got) - gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime() - // Verify it is a set updated after it was created - // TODO: This is currently failing. - assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after it's creation. Was updated %v, which is after %v", gotUpdateTime, hCreated) + require.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got) + gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime() + // Verify it is a set updated after it was created + assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after its creation. Was updated %v, which is after %v", gotUpdateTime, hCreated) - // Clear all values which are hard to compare against. - got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil - } - if tc.res != nil { - tc.res.Item.Version = version + 1 - } + // Clear all values which are hard to compare against. + got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil + tc.res.Item.Version = version + 1 assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "UpdateHost(%q) got response %q, wanted %q", req, got, tc.res) }) } diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index 7519b4137d..0d18ee19f7 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -47,9 +47,8 @@ import ( ) const ( - credentialDomain = "credential" - hostDomain = "host" - missingPortErrStr = "missing port in address" + credentialDomain = "credential" + hostDomain = "host" ) // extraWorkerFilterFunc takes in a set of workers and returns another set, @@ -699,6 +698,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession return nil, handlers.ForbiddenError() } + if t.GetDefaultPort() == 0 { + return nil, handlers.ConflictErrorf("Target does not have default port defined.") + } + // Get the target information repo, err := s.repoFn() if err != nil { @@ -745,87 +748,102 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession selectedWorkers[i], selectedWorkers[j] = selectedWorkers[j], selectedWorkers[i] }) - requestedId := req.GetHostId() - staticHostRepo, err := s.staticHostRepoFn() - if err != nil { - return nil, err - } - pluginHostRepo, err := s.pluginHostRepoFn() - if err != nil { - return nil, err - } + p := strconv.FormatUint(uint64(t.GetDefaultPort()), 10) + var h, hostId, hostSetId string + + switch { + case t.GetAddress() != "": + h = t.GetAddress() + + default: + requestedId := req.GetHostId() + staticHostRepo, err := s.staticHostRepoFn() + if err != nil { + return nil, err + } + pluginHostRepo, err := s.pluginHostRepoFn() + if err != nil { + return nil, err + } - var pluginHostSetIds []string - var endpoints []*host.Endpoint - for _, hSource := range hostSources { - hsId := hSource.Id() - // FIXME: read in type from DB rather than rely on prefix - switch subtypes.SubtypeFromId(hostDomain, hsId) { - case static.Subtype: - eps, err := staticHostRepo.Endpoints(ctx, hsId) + var pluginHostSetIds []string + var endpoints []*host.Endpoint + for _, hSource := range hostSources { + hsId := hSource.Id() + switch subtypes.SubtypeFromId(hostDomain, hsId) { + case static.Subtype: + eps, err := staticHostRepo.Endpoints(ctx, hsId) + if err != nil { + return nil, err + } + endpoints = append(endpoints, eps...) + default: + // Batch the plugin host set ids since each round trip to the plugin + // has the potential to be expensive. + pluginHostSetIds = append(pluginHostSetIds, hsId) + } + } + if len(pluginHostSetIds) > 0 { + eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds) if err != nil { return nil, err } endpoints = append(endpoints, eps...) - default: - // Batch the plugin host set ids since each round trip to the plugin - // has the potential to be expensive. - pluginHostSetIds = append(pluginHostSetIds, hsId) - } - } - if len(pluginHostSetIds) > 0 { - eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds) - if err != nil { - return nil, err } - endpoints = append(endpoints, eps...) - } - if len(endpoints) == 0 && t.GetAddress() == "" { - return nil, handlers.NotFoundErrorf("No host sources or address found for given target.") - } + if len(endpoints) == 0 { + return nil, handlers.NotFoundErrorf("No host sources or address found for given target.") + } - var chosenEndpoint *host.Endpoint - if requestedId != "" { - for _, ep := range endpoints { - if ep.HostId == requestedId { - chosenEndpoint = ep + var chosenEndpoint *host.Endpoint + if requestedId != "" { + for _, ep := range endpoints { + if ep.HostId == requestedId { + chosenEndpoint = ep + } + } + if chosenEndpoint == nil { + // We didn't find it + return nil, handlers.InvalidArgumentErrorf( + "Errors in provided fields.", + map[string]string{ + "host_id": "The requested host id is not available.", + }) } } + if chosenEndpoint == nil { - // We didn't find it - return nil, handlers.InvalidArgumentErrorf( - "Errors in provided fields.", - map[string]string{ - "host_id": "The requested host id is not available.", - }) + chosenEndpoint = endpoints[rand.Intn(len(endpoints))] } - } - - var h, p, hostId, hostSetId string - if chosenEndpoint == nil && len(endpoints) > 0 { - chosenEndpoint = endpoints[rand.Intn(len(endpoints))] - } - if chosenEndpoint != nil { hostId = chosenEndpoint.HostId hostSetId = chosenEndpoint.SetId - h, p, err = net.SplitHostPort(chosenEndpoint.Address) - switch { - case err != nil && strings.Contains(err.Error(), missingPortErrStr): - if t.GetDefaultPort() == 0 { - return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("neither the selected host %q nor the target provides a port to use", chosenEndpoint.HostId)) - } - h = chosenEndpoint.Address - p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10) - case err != nil: - return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoints host address")) - } + h = chosenEndpoint.Address } - if t.GetAddress() != "" && hostId == "" && hostSetId == "" { - h = t.GetAddress() - p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10) + if h == "" { + return nil, handlers.ApiErrorWithCodeAndMessage( + codes.FailedPrecondition, + "No host was discovered after checking target address and host sources.") + } + + // Ensure we don't have a port from the address, which would be unexpected + // FIXME: We've decided to hold off on making this an error until 0.14. In + // the meantime, ignore any port coming from the host address. + hostWithoutPort, _, err := net.SplitHostPort(h) + switch { + case err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr): + // This is what we expect + case err != nil: + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoint host address")) + case err == nil: + h = hostWithoutPort + // Use below logic for 0.14: + /* + return nil, handlers.ApiErrorWithCodeAndMessage( + codes.FailedPrecondition, + "Address specified for use unexpectedly contains a port.") + */ } // Generate the endpoint URL @@ -1566,7 +1584,7 @@ func validateCreateRequest(req *pbs.CreateTargetRequest) error { switch { case err == nil: badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), missingPortErrStr): + case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } @@ -1642,7 +1660,7 @@ func validateUpdateRequest(req *pbs.UpdateTargetRequest) error { switch { case err == nil: badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), missingPortErrStr): + case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } diff --git a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go index fc7fa2ba06..d48fe1287c 100644 --- a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go @@ -2494,6 +2494,7 @@ func TestAuthorizeSession(t *testing.T) { hWithPort := static.TestHosts(t, conn, hcWithPort.GetPublicId(), 1)[0] shsWithPort := static.TestSets(t, conn, hcWithPort.GetPublicId(), 1)[0] _ = static.TestSetMembers(t, conn, shsWithPort.GetPublicId(), []*static.Host{hWithPort}) + hWithPortBareAddress := hWithPort.GetAddress() hWithPort.Address = fmt.Sprintf("%s:54321", hWithPort.GetAddress()) hWithPort, _, err = staticRepo.UpdateHost(ctx, hcWithPort.GetProjectId(), hWithPort, hWithPort.GetVersion(), []string{"address"}) require.NoError(t, err) @@ -2551,7 +2552,7 @@ func TestAuthorizeSession(t *testing.T) { hostSourceId: shsWithPort.GetPublicId(), credSourceId: clsResp.GetItem().GetId(), wantedHostId: hWithPort.GetPublicId(), - wantedEndpoint: hWithPort.GetAddress(), + wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort), }, { name: "plugin host", @@ -3473,9 +3474,8 @@ func TestAuthorizeSession_Errors(t *testing.T) { setup: []func(tcpTarget target.Target) uint32{workerExists, hostExists, libraryExists}, }, { - name: "no port", + name: "no host port", setup: []func(tcpTarget target.Target) uint32{workerExists, hostWithoutPort, libraryExists}, - err: true, }, { name: "no hosts", @@ -3495,7 +3495,7 @@ func TestAuthorizeSession_Errors(t *testing.T) { } for i, tc := range cases { t.Run(tc.name, func(t *testing.T) { - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i)) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i), target.WithDefaultPort(22)) for _, fn := range tc.setup { ver := fn(tar) diff --git a/internal/daemon/controller/handlers/workers/worker_service.go b/internal/daemon/controller/handlers/workers/worker_service.go index ba9f653f70..289a281d1b 100644 --- a/internal/daemon/controller/handlers/workers/worker_service.go +++ b/internal/daemon/controller/handlers/workers/worker_service.go @@ -774,7 +774,7 @@ func (s Service) toProto(ctx context.Context, in *server.Worker, opt ...handlers if outputFields.Has(globals.ScopeIdField) { out.ScopeId = in.GetScopeId() } - if outputFields.Has(globals.DirectlyConnectedDownstreamWorkers) { + if outputFields.Has(globals.DirectlyConnectedDownstreamWorkersField) { out.DirectlyConnectedDownstreamWorkers = downstreamWorkers(ctx, in.GetPublicId(), s.downstreams) } if outputFields.Has(globals.DescriptionField) && in.GetDescription() != "" { diff --git a/internal/daemon/worker/controller_connection.go b/internal/daemon/worker/controller_connection.go index 30e331d19b..d044b36f63 100644 --- a/internal/daemon/worker/controller_connection.go +++ b/internal/daemon/worker/controller_connection.go @@ -49,7 +49,7 @@ func (w *Worker) StartControllerConnections() error { initialAddrs = append(initialAddrs, addr) default: host, port, err := net.SplitHostPort(addr) - if err != nil && strings.Contains(err.Error(), "missing port in address") { + if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) { host, port, err = net.SplitHostPort(net.JoinHostPort(addr, "9201")) } if err != nil {