diff --git a/CHANGELOG.md b/CHANGELOG.md index 72753d6b0d..925e0d18d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ Canonical reference for changes, improvements, and bugfixes for Boundary. ## 0.12.0 (2023/01/24) +### Deprecations/Changes + +* In Boundary 0.9.0, targets were updated to require a default port value. This + had been the original intention; it was a mistake that it was optional. + Unfortunately, due to a separate defect in the update verification logic for + static hosts, it was possible for a host to be updated (but not created) with + a port. This meant that targets could use ports attached to host addresses, + which was not the intention and leads to confusing behavior across different + installations. In this version, updating static hosts will no longer allow + ports to be part of the address; when authorizing a session, any port on such + a host will be ignored in favor of the default port on the target. In Boundary + 0.14.0, this will become an error instead. As a consequence, it means that the + fallback logic for targets that did not have a default port defined is no + longer in service; all targets must now have a default port defined. + ### New and Improved * Direct Address Targets: You can now set an address directly on a target, diff --git a/globals/errors.go b/globals/errors.go new file mode 100644 index 0000000000..a57c02a0b6 --- /dev/null +++ b/globals/errors.go @@ -0,0 +1,3 @@ +package globals + +const MissingPortErrStr = "missing port in address" diff --git a/globals/fields.go b/globals/fields.go index 63ad052300..acff59df20 100644 --- a/globals/fields.go +++ b/globals/fields.go @@ -93,5 +93,6 @@ const ( KeyVersionIdField = "key_version_id" CompletedCountField = "completed_count" TotalCountField = "total_count" - DirectlyConnectedDownstreamWorkers = "directly_connected_downstream_workers" + DirectlyConnectedDownstreamWorkersField = "directly_connected_downstream_workers" + AttributesAddressField = "attributes.address" ) diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index a088c89933..98c0a8c013 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -311,7 +311,7 @@ func (c *Command) Run(args []string) int { for _, upstream := range c.Config.Worker.InitialUpstreams { host, _, err := net.SplitHostPort(upstream) if err != nil { - if strings.Contains(err.Error(), "missing port in address") { + if strings.Contains(err.Error(), globals.MissingPortErrStr) { host = upstream } else { c.UI.Error(fmt.Errorf("Invalid worker upstream address %q: %w", upstream, err).Error()) @@ -355,7 +355,7 @@ func (c *Command) Run(args []string) int { } host, _, err := net.SplitHostPort(ln.Address) if err != nil { - if strings.Contains(err.Error(), "missing port in address") { + if strings.Contains(err.Error(), globals.MissingPortErrStr) { host = ln.Address } else { c.UI.Error(fmt.Errorf("Invalid cluster listener address %q: %w", ln.Address, err).Error()) diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 40f434d74c..224a7226d5 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -1103,7 +1103,7 @@ func (c *Config) SetupWorkerInitialUpstreams() error { } // Best effort see if it's a domain name and if not assume it must match host, _, err := net.SplitHostPort(c.Worker.InitialUpstreams[0]) - if err != nil && strings.Contains(err.Error(), "missing port in address") { + if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) { err = nil host = c.Worker.InitialUpstreams[0] } diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go index 6b1c1f7c55..c0f91eccfa 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go @@ -32,7 +32,6 @@ import ( ) const ( - addressField = "attributes.address" vaultTokenField = "attributes.token" vaultTokenHmacField = "attributes.token_hmac" vaultWorkerFilterField = "attributes.worker_filter" @@ -793,7 +792,7 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateCredentialStoreRe } if attrs.GetAddress().GetValue() == "" { - badFields[addressField] = "Field required for creating a vault credential store." + badFields[globals.AttributesAddressField] = "Field required for creating a vault credential store." } if attrs.GetToken().GetValue() == "" { badFields[vaultTokenField] = "Field required for creating a vault credential store." @@ -842,9 +841,9 @@ func validateUpdateRequest(ctx context.Context, req *pbs.UpdateCredentialStoreRe } attrs := req.GetItem().GetVaultCredentialStoreAttributes() if attrs != nil { - if handlers.MaskContains(req.GetUpdateMask().GetPaths(), addressField) && + if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) && attrs.GetAddress().GetValue() == "" { - badFields[addressField] = "This is a required field and cannot be unset." + badFields[globals.AttributesAddressField] = "This is a required field and cannot be unset." } if handlers.MaskContains(req.GetUpdateMask().GetPaths(), vaultTokenField) && attrs.GetToken().GetValue() == "" { diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go index 4331455119..8086b37c9d 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go @@ -1035,7 +1035,7 @@ func TestUpdateVault(t *testing.T) { { name: "update connection info", req: &pbs.UpdateCredentialStoreRequest{ - UpdateMask: fieldmask("attributes.address", "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"), + UpdateMask: fieldmask(globals.AttributesAddressField, "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"), Item: &pb.CredentialStore{ Attrs: &pb.CredentialStore_VaultCredentialStoreAttributes{ VaultCredentialStoreAttributes: &pb.VaultCredentialStoreAttributes{ diff --git a/internal/daemon/controller/handlers/hosts/host_service.go b/internal/daemon/controller/handlers/hosts/host_service.go index 1a5bf70f54..89a339cd86 100644 --- a/internal/daemon/controller/handlers/hosts/host_service.go +++ b/internal/daemon/controller/handlers/hosts/host_service.go @@ -637,21 +637,22 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error { attrs := req.GetItem().GetStaticHostAttributes() switch { case attrs == nil: - badFields["attributes"] = "This is a required field." + badFields[globals.AttributesField] = "This is a required field." default: if attrs.GetAddress() == nil || len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength || len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) - } - _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) - switch { - case err == nil: - badFields["attributes.address"] = "Address for static hosts does not support a port." - case strings.Contains(err.Error(), "missing port in address"): - // Bare hostname, which we want - default: - badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err) + badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } else { + _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) + switch { + case err == nil: + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + case strings.Contains(err.Error(), globals.MissingPortErrStr): + // Bare hostname, which we want + default: + badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) + } } } case plugin.Subtype: @@ -668,17 +669,26 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error { case static.Subtype: if req.GetItem().GetType() != "" && req.GetItem().GetType() != static.Subtype.String() { badFields[globals.TypeField] = "Cannot modify the resource type." + } + if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) { attrs := req.GetItem().GetStaticHostAttributes() - - if handlers.MaskContains(req.GetUpdateMask().GetPaths(), "attributes.address") { - switch { - case attrs == nil: - badFields["attributes"] = "Attributes field not supplied request" - default: - if attrs.GetAddress() == nil || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || - len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { - badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + switch { + case attrs == nil: + badFields[globals.AttributesField] = "Attributes field not supplied in request" + default: + if attrs.GetAddress() == nil || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength || + len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { + badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + } else { + _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) + switch { + case err == nil: + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + case strings.Contains(err.Error(), globals.MissingPortErrStr): + // Bare hostname, which we want + default: + badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) } } } diff --git a/internal/daemon/controller/handlers/hosts/host_service_test.go b/internal/daemon/controller/handlers/hosts/host_service_test.go index 1a5acd3c08..6ea534a523 100644 --- a/internal/daemon/controller/handlers/hosts/host_service_test.go +++ b/internal/daemon/controller/handlers/hosts/host_service_test.go @@ -771,6 +771,21 @@ func TestCreate(t *testing.T) { res: nil, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Can't specify port", + req: &pbs.CreateHostRequest{Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Name: &wrappers.StringValue{Value: "port name"}, + Description: &wrappers.StringValue{Value: "port desc"}, + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("123.456.789:12345"), + }, + }, + }}, + res: nil, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, { name: "Can't specify Created Time", req: &pbs.CreateHostRequest{Item: &pb.Host{ @@ -798,11 +813,16 @@ func TestCreate(t *testing.T) { got, gErr := s.CreateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), tc.req) if tc.err != nil { require.Error(gErr) + assert.Nil(got) assert.True(errors.Is(gErr, tc.err), "CreateHost(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) if tc.wantErrContains != "" { assert.Contains(gErr.Error(), tc.wantErrContains) } + return } + require.NoError(gErr) + require.NotNil(got) + if got != nil { assert.Contains(got.GetUri(), tc.res.GetUri()) assert.True(strings.HasPrefix(got.GetItem().GetId(), static.HostPrefix)) @@ -868,22 +888,21 @@ func TestUpdate_Static(t *testing.T) { } hCreated := h.GetCreateTime().GetTimestamp().AsTime() - toMerge := &pbs.UpdateHostRequest{ - Id: h.GetPublicId(), - } tested, err := hosts.NewService(repoFn, pluginRepoFn) require.NoError(t, err, "Failed to create a new host set service.") cases := []struct { - name string - req *pbs.UpdateHostRequest - res *pbs.UpdateHostResponse - err error + name string + req *pbs.UpdateHostRequest + res *pbs.UpdateHostResponse + err error + wantErrContains string }{ { name: "Update an Existing Host", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name", "description", "type"}, }, @@ -915,6 +934,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Multiple Paths in single string", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name,description,type"}, }, @@ -946,16 +966,19 @@ func TestUpdate_Static(t *testing.T) { { name: "No Update Mask", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "UpdateMask not provided", }, { - name: "No Update Mask", + name: "Changing Type", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name,type"}, }, @@ -964,33 +987,39 @@ func TestUpdate_Static(t *testing.T) { Type: "ec2", }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Cannot modify the resource type", }, { name: "Empty Path", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{Paths: []string{}}, Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "No valid fields provided", }, { - name: "Only non-existant paths in Mask", + name: "Only non-existent paths in Mask", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{Paths: []string{"nonexistant_field"}}, Item: &pb.Host{ Name: &wrappers.StringValue{Value: "updated name"}, Description: &wrappers.StringValue{Value: "updated desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "No valid fields provided", }, { name: "Unset Name", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1019,6 +1048,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Unset Description", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"description"}, }, @@ -1047,6 +1077,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Update Only Name", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1077,6 +1108,7 @@ func TestUpdate_Static(t *testing.T) { { name: "Update Only Description", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"description"}, }, @@ -1117,12 +1149,13 @@ func TestUpdate_Static(t *testing.T) { Description: &wrappers.StringValue{Value: "desc"}, }, }, - err: handlers.ApiErrorWithCode(codes.NotFound), + err: handlers.ApiErrorWithCode(codes.NotFound), + wantErrContains: "Resource not found", }, { name: "Cant change Id", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"id"}, }, @@ -1133,15 +1166,15 @@ func TestUpdate_Static(t *testing.T) { Description: &wrappers.StringValue{Value: "new desc"}, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Cant unset address", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ - Paths: []string{"attributes.address"}, + Paths: []string{globals.AttributesAddressField}, }, Item: &pb.Host{ Attrs: &pb.Host_StaticHostAttributes{ @@ -1151,15 +1184,15 @@ func TestUpdate_Static(t *testing.T) { }, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Address length must be", }, { name: "Cant set address to empty string", req: &pbs.UpdateHostRequest{ - Id: hc.GetPublicId(), + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ - Paths: []string{"attributes.address"}, + Paths: []string{globals.AttributesAddressField}, }, Item: &pb.Host{ Attrs: &pb.Host_StaticHostAttributes{ @@ -1169,12 +1202,33 @@ func TestUpdate_Static(t *testing.T) { }, }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Address length must be", + }, + { + name: "Cant specify port in address", + req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), + UpdateMask: &field_mask.FieldMask{ + Paths: []string{globals.AttributesAddressField}, + }, + Item: &pb.Host{ + Name: &wrappers.StringValue{Value: "port name"}, + Description: &wrappers.StringValue{Value: "port desc"}, + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("123.456.789:12345"), + }, + }, + }, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "does not support a port", }, { name: "Cant specify Created Time", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"created_time"}, }, @@ -1182,12 +1236,13 @@ func TestUpdate_Static(t *testing.T) { CreatedTime: timestamppb.Now(), }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Cant specify Updated Time", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"updated_time"}, }, @@ -1195,12 +1250,13 @@ func TestUpdate_Static(t *testing.T) { UpdatedTime: timestamppb.Now(), }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "This is a read only field", }, { name: "Valid mask, cant specify type", req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), UpdateMask: &field_mask.FieldMask{ Paths: []string{"name"}, }, @@ -1208,8 +1264,8 @@ func TestUpdate_Static(t *testing.T) { Type: "Unknown", }, }, - res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + wantErrContains: "Cannot modify the resource type", }, } for _, tc := range cases { @@ -1217,8 +1273,7 @@ func TestUpdate_Static(t *testing.T) { assert, require := assert.New(t), require.New(t) tc.req.Item.Version = version - req := proto.Clone(toMerge).(*pbs.UpdateHostRequest) - proto.Merge(req, tc.req) + req := tc.req // Test some bad versions req.Item.Version = version + 2 @@ -1232,26 +1287,26 @@ func TestUpdate_Static(t *testing.T) { got, gErr := tested.UpdateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), req) if tc.err != nil { require.Error(gErr) + assert.Nil(got) assert.True(errors.Is(gErr, tc.err), "UpdateHost(%+v) got error %v, wanted %v", req, gErr, tc.err) + if tc.wantErrContains != "" { + assert.Contains(gErr.Error(), tc.wantErrContains) + } + return } - if tc.err == nil { - defer resetHost() - } + defer resetHost() + require.NoError(gErr) + require.NotNil(got) - if got != nil { - assert.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got) - gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime() - // Verify it is a set updated after it was created - // TODO: This is currently failing. - assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after it's creation. Was updated %v, which is after %v", gotUpdateTime, hCreated) + require.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got) + gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime() + // Verify it is a set updated after it was created + assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after its creation. Was updated %v, which is after %v", gotUpdateTime, hCreated) - // Clear all values which are hard to compare against. - got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil - } - if tc.res != nil { - tc.res.Item.Version = version + 1 - } + // Clear all values which are hard to compare against. + got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil + tc.res.Item.Version = version + 1 assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "UpdateHost(%q) got response %q, wanted %q", req, got, tc.res) }) } diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index 7519b4137d..0d18ee19f7 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -47,9 +47,8 @@ import ( ) const ( - credentialDomain = "credential" - hostDomain = "host" - missingPortErrStr = "missing port in address" + credentialDomain = "credential" + hostDomain = "host" ) // extraWorkerFilterFunc takes in a set of workers and returns another set, @@ -699,6 +698,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession return nil, handlers.ForbiddenError() } + if t.GetDefaultPort() == 0 { + return nil, handlers.ConflictErrorf("Target does not have default port defined.") + } + // Get the target information repo, err := s.repoFn() if err != nil { @@ -745,87 +748,102 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession selectedWorkers[i], selectedWorkers[j] = selectedWorkers[j], selectedWorkers[i] }) - requestedId := req.GetHostId() - staticHostRepo, err := s.staticHostRepoFn() - if err != nil { - return nil, err - } - pluginHostRepo, err := s.pluginHostRepoFn() - if err != nil { - return nil, err - } + p := strconv.FormatUint(uint64(t.GetDefaultPort()), 10) + var h, hostId, hostSetId string + + switch { + case t.GetAddress() != "": + h = t.GetAddress() + + default: + requestedId := req.GetHostId() + staticHostRepo, err := s.staticHostRepoFn() + if err != nil { + return nil, err + } + pluginHostRepo, err := s.pluginHostRepoFn() + if err != nil { + return nil, err + } - var pluginHostSetIds []string - var endpoints []*host.Endpoint - for _, hSource := range hostSources { - hsId := hSource.Id() - // FIXME: read in type from DB rather than rely on prefix - switch subtypes.SubtypeFromId(hostDomain, hsId) { - case static.Subtype: - eps, err := staticHostRepo.Endpoints(ctx, hsId) + var pluginHostSetIds []string + var endpoints []*host.Endpoint + for _, hSource := range hostSources { + hsId := hSource.Id() + switch subtypes.SubtypeFromId(hostDomain, hsId) { + case static.Subtype: + eps, err := staticHostRepo.Endpoints(ctx, hsId) + if err != nil { + return nil, err + } + endpoints = append(endpoints, eps...) + default: + // Batch the plugin host set ids since each round trip to the plugin + // has the potential to be expensive. + pluginHostSetIds = append(pluginHostSetIds, hsId) + } + } + if len(pluginHostSetIds) > 0 { + eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds) if err != nil { return nil, err } endpoints = append(endpoints, eps...) - default: - // Batch the plugin host set ids since each round trip to the plugin - // has the potential to be expensive. - pluginHostSetIds = append(pluginHostSetIds, hsId) - } - } - if len(pluginHostSetIds) > 0 { - eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds) - if err != nil { - return nil, err } - endpoints = append(endpoints, eps...) - } - if len(endpoints) == 0 && t.GetAddress() == "" { - return nil, handlers.NotFoundErrorf("No host sources or address found for given target.") - } + if len(endpoints) == 0 { + return nil, handlers.NotFoundErrorf("No host sources or address found for given target.") + } - var chosenEndpoint *host.Endpoint - if requestedId != "" { - for _, ep := range endpoints { - if ep.HostId == requestedId { - chosenEndpoint = ep + var chosenEndpoint *host.Endpoint + if requestedId != "" { + for _, ep := range endpoints { + if ep.HostId == requestedId { + chosenEndpoint = ep + } + } + if chosenEndpoint == nil { + // We didn't find it + return nil, handlers.InvalidArgumentErrorf( + "Errors in provided fields.", + map[string]string{ + "host_id": "The requested host id is not available.", + }) } } + if chosenEndpoint == nil { - // We didn't find it - return nil, handlers.InvalidArgumentErrorf( - "Errors in provided fields.", - map[string]string{ - "host_id": "The requested host id is not available.", - }) + chosenEndpoint = endpoints[rand.Intn(len(endpoints))] } - } - - var h, p, hostId, hostSetId string - if chosenEndpoint == nil && len(endpoints) > 0 { - chosenEndpoint = endpoints[rand.Intn(len(endpoints))] - } - if chosenEndpoint != nil { hostId = chosenEndpoint.HostId hostSetId = chosenEndpoint.SetId - h, p, err = net.SplitHostPort(chosenEndpoint.Address) - switch { - case err != nil && strings.Contains(err.Error(), missingPortErrStr): - if t.GetDefaultPort() == 0 { - return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("neither the selected host %q nor the target provides a port to use", chosenEndpoint.HostId)) - } - h = chosenEndpoint.Address - p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10) - case err != nil: - return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoints host address")) - } + h = chosenEndpoint.Address } - if t.GetAddress() != "" && hostId == "" && hostSetId == "" { - h = t.GetAddress() - p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10) + if h == "" { + return nil, handlers.ApiErrorWithCodeAndMessage( + codes.FailedPrecondition, + "No host was discovered after checking target address and host sources.") + } + + // Ensure we don't have a port from the address, which would be unexpected + // FIXME: We've decided to hold off on making this an error until 0.14. In + // the meantime, ignore any port coming from the host address. + hostWithoutPort, _, err := net.SplitHostPort(h) + switch { + case err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr): + // This is what we expect + case err != nil: + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoint host address")) + case err == nil: + h = hostWithoutPort + // Use below logic for 0.14: + /* + return nil, handlers.ApiErrorWithCodeAndMessage( + codes.FailedPrecondition, + "Address specified for use unexpectedly contains a port.") + */ } // Generate the endpoint URL @@ -1566,7 +1584,7 @@ func validateCreateRequest(req *pbs.CreateTargetRequest) error { switch { case err == nil: badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), missingPortErrStr): + case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } @@ -1642,7 +1660,7 @@ func validateUpdateRequest(req *pbs.UpdateTargetRequest) error { switch { case err == nil: badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), missingPortErrStr): + case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } diff --git a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go index fc7fa2ba06..d48fe1287c 100644 --- a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go @@ -2494,6 +2494,7 @@ func TestAuthorizeSession(t *testing.T) { hWithPort := static.TestHosts(t, conn, hcWithPort.GetPublicId(), 1)[0] shsWithPort := static.TestSets(t, conn, hcWithPort.GetPublicId(), 1)[0] _ = static.TestSetMembers(t, conn, shsWithPort.GetPublicId(), []*static.Host{hWithPort}) + hWithPortBareAddress := hWithPort.GetAddress() hWithPort.Address = fmt.Sprintf("%s:54321", hWithPort.GetAddress()) hWithPort, _, err = staticRepo.UpdateHost(ctx, hcWithPort.GetProjectId(), hWithPort, hWithPort.GetVersion(), []string{"address"}) require.NoError(t, err) @@ -2551,7 +2552,7 @@ func TestAuthorizeSession(t *testing.T) { hostSourceId: shsWithPort.GetPublicId(), credSourceId: clsResp.GetItem().GetId(), wantedHostId: hWithPort.GetPublicId(), - wantedEndpoint: hWithPort.GetAddress(), + wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort), }, { name: "plugin host", @@ -3473,9 +3474,8 @@ func TestAuthorizeSession_Errors(t *testing.T) { setup: []func(tcpTarget target.Target) uint32{workerExists, hostExists, libraryExists}, }, { - name: "no port", + name: "no host port", setup: []func(tcpTarget target.Target) uint32{workerExists, hostWithoutPort, libraryExists}, - err: true, }, { name: "no hosts", @@ -3495,7 +3495,7 @@ func TestAuthorizeSession_Errors(t *testing.T) { } for i, tc := range cases { t.Run(tc.name, func(t *testing.T) { - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i)) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i), target.WithDefaultPort(22)) for _, fn := range tc.setup { ver := fn(tar) diff --git a/internal/daemon/controller/handlers/workers/worker_service.go b/internal/daemon/controller/handlers/workers/worker_service.go index ba9f653f70..289a281d1b 100644 --- a/internal/daemon/controller/handlers/workers/worker_service.go +++ b/internal/daemon/controller/handlers/workers/worker_service.go @@ -774,7 +774,7 @@ func (s Service) toProto(ctx context.Context, in *server.Worker, opt ...handlers if outputFields.Has(globals.ScopeIdField) { out.ScopeId = in.GetScopeId() } - if outputFields.Has(globals.DirectlyConnectedDownstreamWorkers) { + if outputFields.Has(globals.DirectlyConnectedDownstreamWorkersField) { out.DirectlyConnectedDownstreamWorkers = downstreamWorkers(ctx, in.GetPublicId(), s.downstreams) } if outputFields.Has(globals.DescriptionField) && in.GetDescription() != "" { diff --git a/internal/daemon/worker/controller_connection.go b/internal/daemon/worker/controller_connection.go index 30e331d19b..d044b36f63 100644 --- a/internal/daemon/worker/controller_connection.go +++ b/internal/daemon/worker/controller_connection.go @@ -49,7 +49,7 @@ func (w *Worker) StartControllerConnections() error { initialAddrs = append(initialAddrs, addr) default: host, port, err := net.SplitHostPort(addr) - if err != nil && strings.Contains(err.Error(), "missing port in address") { + if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) { host, port, err = net.SplitHostPort(net.JoinHostPort(addr, "9201")) } if err != nil {