Fix target port handling (#2846)

* Fix target port handling

This fixes two issues that compounded on each other (see the Changelog
update for more information):

* The verification logic for hosts was not correct for update operations
(in multiple ways) which meant that a host could be updated after
creation to have a port. Targets had a previously fixed bug where they
did not require a default port, which meant that ports could be used
from hosts
* A recent (unreleased) change had prioritized any port coming from the
host over the default port, which would mean there was no way if a
host had a port specified to use it in multiple targets.

This fixes the update verification logic, and strikes a middle ground
between breaking things and not by allowing existing addresses with
ports to be used with targets but ignoring that port, instead requiring
targets to have default port set at authorize time (currently there is
backwards compat that does not require this due to the original optional
port bug).

* Update CHANGELOG.md

Co-authored-by: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
pull/2862/head
Jeff Mitchell 3 years ago committed by GitHub
parent 5fd66b8d88
commit 19180af0eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,21 @@ Canonical reference for changes, improvements, and bugfixes for Boundary.
## 0.12.0 (2023/01/24)
### Deprecations/Changes
* In Boundary 0.9.0, targets were updated to require a default port value. This
had been the original intention; it was a mistake that it was optional.
Unfortunately, due to a separate defect in the update verification logic for
static hosts, it was possible for a host to be updated (but not created) with
a port. This meant that targets could use ports attached to host addresses,
which was not the intention and leads to confusing behavior across different
installations. In this version, updating static hosts will no longer allow
ports to be part of the address; when authorizing a session, any port on such
a host will be ignored in favor of the default port on the target. In Boundary
0.14.0, this will become an error instead. As a consequence, it means that the
fallback logic for targets that did not have a default port defined is no
longer in service; all targets must now have a default port defined.
### New and Improved
* Direct Address Targets: You can now set an address directly on a target,

@ -0,0 +1,3 @@
package globals
const MissingPortErrStr = "missing port in address"

@ -93,5 +93,6 @@ const (
KeyVersionIdField = "key_version_id"
CompletedCountField = "completed_count"
TotalCountField = "total_count"
DirectlyConnectedDownstreamWorkers = "directly_connected_downstream_workers"
DirectlyConnectedDownstreamWorkersField = "directly_connected_downstream_workers"
AttributesAddressField = "attributes.address"
)

@ -311,7 +311,7 @@ func (c *Command) Run(args []string) int {
for _, upstream := range c.Config.Worker.InitialUpstreams {
host, _, err := net.SplitHostPort(upstream)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
if strings.Contains(err.Error(), globals.MissingPortErrStr) {
host = upstream
} else {
c.UI.Error(fmt.Errorf("Invalid worker upstream address %q: %w", upstream, err).Error())
@ -355,7 +355,7 @@ func (c *Command) Run(args []string) int {
}
host, _, err := net.SplitHostPort(ln.Address)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
if strings.Contains(err.Error(), globals.MissingPortErrStr) {
host = ln.Address
} else {
c.UI.Error(fmt.Errorf("Invalid cluster listener address %q: %w", ln.Address, err).Error())

@ -1103,7 +1103,7 @@ func (c *Config) SetupWorkerInitialUpstreams() error {
}
// Best effort see if it's a domain name and if not assume it must match
host, _, err := net.SplitHostPort(c.Worker.InitialUpstreams[0])
if err != nil && strings.Contains(err.Error(), "missing port in address") {
if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) {
err = nil
host = c.Worker.InitialUpstreams[0]
}

@ -32,7 +32,6 @@ import (
)
const (
addressField = "attributes.address"
vaultTokenField = "attributes.token"
vaultTokenHmacField = "attributes.token_hmac"
vaultWorkerFilterField = "attributes.worker_filter"
@ -793,7 +792,7 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateCredentialStoreRe
}
if attrs.GetAddress().GetValue() == "" {
badFields[addressField] = "Field required for creating a vault credential store."
badFields[globals.AttributesAddressField] = "Field required for creating a vault credential store."
}
if attrs.GetToken().GetValue() == "" {
badFields[vaultTokenField] = "Field required for creating a vault credential store."
@ -842,9 +841,9 @@ func validateUpdateRequest(ctx context.Context, req *pbs.UpdateCredentialStoreRe
}
attrs := req.GetItem().GetVaultCredentialStoreAttributes()
if attrs != nil {
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), addressField) &&
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) &&
attrs.GetAddress().GetValue() == "" {
badFields[addressField] = "This is a required field and cannot be unset."
badFields[globals.AttributesAddressField] = "This is a required field and cannot be unset."
}
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), vaultTokenField) &&
attrs.GetToken().GetValue() == "" {

@ -1035,7 +1035,7 @@ func TestUpdateVault(t *testing.T) {
{
name: "update connection info",
req: &pbs.UpdateCredentialStoreRequest{
UpdateMask: fieldmask("attributes.address", "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"),
UpdateMask: fieldmask(globals.AttributesAddressField, "attributes.client_certificate", "attributes.client_certificate_key", "attributes.ca_cert", "attributes.token"),
Item: &pb.CredentialStore{
Attrs: &pb.CredentialStore_VaultCredentialStoreAttributes{
VaultCredentialStoreAttributes: &pb.VaultCredentialStoreAttributes{

@ -637,21 +637,22 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error {
attrs := req.GetItem().GetStaticHostAttributes()
switch {
case attrs == nil:
badFields["attributes"] = "This is a required field."
badFields[globals.AttributesField] = "This is a required field."
default:
if attrs.GetAddress() == nil ||
len(attrs.GetAddress().GetValue()) < static.MinHostAddressLength ||
len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
}
_, _, err := net.SplitHostPort(attrs.GetAddress().GetValue())
switch {
case err == nil:
badFields["attributes.address"] = "Address for static hosts does not support a port."
case strings.Contains(err.Error(), "missing port in address"):
// Bare hostname, which we want
default:
badFields["attributes.address"] = fmt.Sprintf("Error parsing address: %v.", err)
badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
} else {
_, _, err := net.SplitHostPort(attrs.GetAddress().GetValue())
switch {
case err == nil:
badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port."
case strings.Contains(err.Error(), globals.MissingPortErrStr):
// Bare hostname, which we want
default:
badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err)
}
}
}
case plugin.Subtype:
@ -668,17 +669,26 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error {
case static.Subtype:
if req.GetItem().GetType() != "" && req.GetItem().GetType() != static.Subtype.String() {
badFields[globals.TypeField] = "Cannot modify the resource type."
}
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), globals.AttributesAddressField) {
attrs := req.GetItem().GetStaticHostAttributes()
if handlers.MaskContains(req.GetUpdateMask().GetPaths(), "attributes.address") {
switch {
case attrs == nil:
badFields["attributes"] = "Attributes field not supplied request"
default:
if attrs.GetAddress() == nil ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength {
badFields["attributes.address"] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
switch {
case attrs == nil:
badFields[globals.AttributesField] = "Attributes field not supplied in request"
default:
if attrs.GetAddress() == nil ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) < static.MinHostAddressLength ||
len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength {
badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength)
} else {
_, _, err := net.SplitHostPort(attrs.GetAddress().GetValue())
switch {
case err == nil:
badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port."
case strings.Contains(err.Error(), globals.MissingPortErrStr):
// Bare hostname, which we want
default:
badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err)
}
}
}

@ -771,6 +771,21 @@ func TestCreate(t *testing.T) {
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Can't specify port",
req: &pbs.CreateHostRequest{Item: &pb.Host{
HostCatalogId: hc.GetPublicId(),
Name: &wrappers.StringValue{Value: "port name"},
Description: &wrappers.StringValue{Value: "port desc"},
Attrs: &pb.Host_StaticHostAttributes{
StaticHostAttributes: &pb.StaticHostAttributes{
Address: wrapperspb.String("123.456.789:12345"),
},
},
}},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Can't specify Created Time",
req: &pbs.CreateHostRequest{Item: &pb.Host{
@ -798,11 +813,16 @@ func TestCreate(t *testing.T) {
got, gErr := s.CreateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), tc.req)
if tc.err != nil {
require.Error(gErr)
assert.Nil(got)
assert.True(errors.Is(gErr, tc.err), "CreateHost(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
if tc.wantErrContains != "" {
assert.Contains(gErr.Error(), tc.wantErrContains)
}
return
}
require.NoError(gErr)
require.NotNil(got)
if got != nil {
assert.Contains(got.GetUri(), tc.res.GetUri())
assert.True(strings.HasPrefix(got.GetItem().GetId(), static.HostPrefix))
@ -868,22 +888,21 @@ func TestUpdate_Static(t *testing.T) {
}
hCreated := h.GetCreateTime().GetTimestamp().AsTime()
toMerge := &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
}
tested, err := hosts.NewService(repoFn, pluginRepoFn)
require.NoError(t, err, "Failed to create a new host set service.")
cases := []struct {
name string
req *pbs.UpdateHostRequest
res *pbs.UpdateHostResponse
err error
name string
req *pbs.UpdateHostRequest
res *pbs.UpdateHostResponse
err error
wantErrContains string
}{
{
name: "Update an Existing Host",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name", "description", "type"},
},
@ -915,6 +934,7 @@ func TestUpdate_Static(t *testing.T) {
{
name: "Multiple Paths in single string",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name,description,type"},
},
@ -946,16 +966,19 @@ func TestUpdate_Static(t *testing.T) {
{
name: "No Update Mask",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
Item: &pb.Host{
Name: &wrappers.StringValue{Value: "updated name"},
Description: &wrappers.StringValue{Value: "updated desc"},
},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "UpdateMask not provided",
},
{
name: "No Update Mask",
name: "Changing Type",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name,type"},
},
@ -964,33 +987,39 @@ func TestUpdate_Static(t *testing.T) {
Type: "ec2",
},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "Cannot modify the resource type",
},
{
name: "Empty Path",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{Paths: []string{}},
Item: &pb.Host{
Name: &wrappers.StringValue{Value: "updated name"},
Description: &wrappers.StringValue{Value: "updated desc"},
},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "No valid fields provided",
},
{
name: "Only non-existant paths in Mask",
name: "Only non-existent paths in Mask",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{Paths: []string{"nonexistant_field"}},
Item: &pb.Host{
Name: &wrappers.StringValue{Value: "updated name"},
Description: &wrappers.StringValue{Value: "updated desc"},
},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "No valid fields provided",
},
{
name: "Unset Name",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name"},
},
@ -1019,6 +1048,7 @@ func TestUpdate_Static(t *testing.T) {
{
name: "Unset Description",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"description"},
},
@ -1047,6 +1077,7 @@ func TestUpdate_Static(t *testing.T) {
{
name: "Update Only Name",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name"},
},
@ -1077,6 +1108,7 @@ func TestUpdate_Static(t *testing.T) {
{
name: "Update Only Description",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"description"},
},
@ -1117,12 +1149,13 @@ func TestUpdate_Static(t *testing.T) {
Description: &wrappers.StringValue{Value: "desc"},
},
},
err: handlers.ApiErrorWithCode(codes.NotFound),
err: handlers.ApiErrorWithCode(codes.NotFound),
wantErrContains: "Resource not found",
},
{
name: "Cant change Id",
req: &pbs.UpdateHostRequest{
Id: hc.GetPublicId(),
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"id"},
},
@ -1133,15 +1166,15 @@ func TestUpdate_Static(t *testing.T) {
Description: &wrappers.StringValue{Value: "new desc"},
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "This is a read only field",
},
{
name: "Cant unset address",
req: &pbs.UpdateHostRequest{
Id: hc.GetPublicId(),
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"attributes.address"},
Paths: []string{globals.AttributesAddressField},
},
Item: &pb.Host{
Attrs: &pb.Host_StaticHostAttributes{
@ -1151,15 +1184,15 @@ func TestUpdate_Static(t *testing.T) {
},
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "Address length must be",
},
{
name: "Cant set address to empty string",
req: &pbs.UpdateHostRequest{
Id: hc.GetPublicId(),
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"attributes.address"},
Paths: []string{globals.AttributesAddressField},
},
Item: &pb.Host{
Attrs: &pb.Host_StaticHostAttributes{
@ -1169,12 +1202,33 @@ func TestUpdate_Static(t *testing.T) {
},
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "Address length must be",
},
{
name: "Cant specify port in address",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{globals.AttributesAddressField},
},
Item: &pb.Host{
Name: &wrappers.StringValue{Value: "port name"},
Description: &wrappers.StringValue{Value: "port desc"},
Attrs: &pb.Host_StaticHostAttributes{
StaticHostAttributes: &pb.StaticHostAttributes{
Address: wrapperspb.String("123.456.789:12345"),
},
},
},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "does not support a port",
},
{
name: "Cant specify Created Time",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"created_time"},
},
@ -1182,12 +1236,13 @@ func TestUpdate_Static(t *testing.T) {
CreatedTime: timestamppb.Now(),
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "This is a read only field",
},
{
name: "Cant specify Updated Time",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"updated_time"},
},
@ -1195,12 +1250,13 @@ func TestUpdate_Static(t *testing.T) {
UpdatedTime: timestamppb.Now(),
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "This is a read only field",
},
{
name: "Valid mask, cant specify type",
req: &pbs.UpdateHostRequest{
Id: h.GetPublicId(),
UpdateMask: &field_mask.FieldMask{
Paths: []string{"name"},
},
@ -1208,8 +1264,8 @@ func TestUpdate_Static(t *testing.T) {
Type: "Unknown",
},
},
res: nil,
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
wantErrContains: "Cannot modify the resource type",
},
}
for _, tc := range cases {
@ -1217,8 +1273,7 @@ func TestUpdate_Static(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tc.req.Item.Version = version
req := proto.Clone(toMerge).(*pbs.UpdateHostRequest)
proto.Merge(req, tc.req)
req := tc.req
// Test some bad versions
req.Item.Version = version + 2
@ -1232,26 +1287,26 @@ func TestUpdate_Static(t *testing.T) {
got, gErr := tested.UpdateHost(auth.DisabledAuthTestContext(iamRepoFn, proj.GetPublicId()), req)
if tc.err != nil {
require.Error(gErr)
assert.Nil(got)
assert.True(errors.Is(gErr, tc.err), "UpdateHost(%+v) got error %v, wanted %v", req, gErr, tc.err)
if tc.wantErrContains != "" {
assert.Contains(gErr.Error(), tc.wantErrContains)
}
return
}
if tc.err == nil {
defer resetHost()
}
defer resetHost()
require.NoError(gErr)
require.NotNil(got)
if got != nil {
assert.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got)
gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime()
// Verify it is a set updated after it was created
// TODO: This is currently failing.
assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after it's creation. Was updated %v, which is after %v", gotUpdateTime, hCreated)
require.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got)
gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime()
// Verify it is a set updated after it was created
assert.True(gotUpdateTime.After(hCreated), "Updated set should have been updated after its creation. Was updated %v, which is after %v", gotUpdateTime, hCreated)
// Clear all values which are hard to compare against.
got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil
}
if tc.res != nil {
tc.res.Item.Version = version + 1
}
// Clear all values which are hard to compare against.
got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil
tc.res.Item.Version = version + 1
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "UpdateHost(%q) got response %q, wanted %q", req, got, tc.res)
})
}

@ -47,9 +47,8 @@ import (
)
const (
credentialDomain = "credential"
hostDomain = "host"
missingPortErrStr = "missing port in address"
credentialDomain = "credential"
hostDomain = "host"
)
// extraWorkerFilterFunc takes in a set of workers and returns another set,
@ -699,6 +698,10 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
return nil, handlers.ForbiddenError()
}
if t.GetDefaultPort() == 0 {
return nil, handlers.ConflictErrorf("Target does not have default port defined.")
}
// Get the target information
repo, err := s.repoFn()
if err != nil {
@ -745,87 +748,102 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
selectedWorkers[i], selectedWorkers[j] = selectedWorkers[j], selectedWorkers[i]
})
requestedId := req.GetHostId()
staticHostRepo, err := s.staticHostRepoFn()
if err != nil {
return nil, err
}
pluginHostRepo, err := s.pluginHostRepoFn()
if err != nil {
return nil, err
}
p := strconv.FormatUint(uint64(t.GetDefaultPort()), 10)
var h, hostId, hostSetId string
switch {
case t.GetAddress() != "":
h = t.GetAddress()
default:
requestedId := req.GetHostId()
staticHostRepo, err := s.staticHostRepoFn()
if err != nil {
return nil, err
}
pluginHostRepo, err := s.pluginHostRepoFn()
if err != nil {
return nil, err
}
var pluginHostSetIds []string
var endpoints []*host.Endpoint
for _, hSource := range hostSources {
hsId := hSource.Id()
// FIXME: read in type from DB rather than rely on prefix
switch subtypes.SubtypeFromId(hostDomain, hsId) {
case static.Subtype:
eps, err := staticHostRepo.Endpoints(ctx, hsId)
var pluginHostSetIds []string
var endpoints []*host.Endpoint
for _, hSource := range hostSources {
hsId := hSource.Id()
switch subtypes.SubtypeFromId(hostDomain, hsId) {
case static.Subtype:
eps, err := staticHostRepo.Endpoints(ctx, hsId)
if err != nil {
return nil, err
}
endpoints = append(endpoints, eps...)
default:
// Batch the plugin host set ids since each round trip to the plugin
// has the potential to be expensive.
pluginHostSetIds = append(pluginHostSetIds, hsId)
}
}
if len(pluginHostSetIds) > 0 {
eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds)
if err != nil {
return nil, err
}
endpoints = append(endpoints, eps...)
default:
// Batch the plugin host set ids since each round trip to the plugin
// has the potential to be expensive.
pluginHostSetIds = append(pluginHostSetIds, hsId)
}
}
if len(pluginHostSetIds) > 0 {
eps, err := pluginHostRepo.Endpoints(ctx, pluginHostSetIds)
if err != nil {
return nil, err
}
endpoints = append(endpoints, eps...)
}
if len(endpoints) == 0 && t.GetAddress() == "" {
return nil, handlers.NotFoundErrorf("No host sources or address found for given target.")
}
if len(endpoints) == 0 {
return nil, handlers.NotFoundErrorf("No host sources or address found for given target.")
}
var chosenEndpoint *host.Endpoint
if requestedId != "" {
for _, ep := range endpoints {
if ep.HostId == requestedId {
chosenEndpoint = ep
var chosenEndpoint *host.Endpoint
if requestedId != "" {
for _, ep := range endpoints {
if ep.HostId == requestedId {
chosenEndpoint = ep
}
}
if chosenEndpoint == nil {
// We didn't find it
return nil, handlers.InvalidArgumentErrorf(
"Errors in provided fields.",
map[string]string{
"host_id": "The requested host id is not available.",
})
}
}
if chosenEndpoint == nil {
// We didn't find it
return nil, handlers.InvalidArgumentErrorf(
"Errors in provided fields.",
map[string]string{
"host_id": "The requested host id is not available.",
})
chosenEndpoint = endpoints[rand.Intn(len(endpoints))]
}
}
var h, p, hostId, hostSetId string
if chosenEndpoint == nil && len(endpoints) > 0 {
chosenEndpoint = endpoints[rand.Intn(len(endpoints))]
}
if chosenEndpoint != nil {
hostId = chosenEndpoint.HostId
hostSetId = chosenEndpoint.SetId
h, p, err = net.SplitHostPort(chosenEndpoint.Address)
switch {
case err != nil && strings.Contains(err.Error(), missingPortErrStr):
if t.GetDefaultPort() == 0 {
return nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("neither the selected host %q nor the target provides a port to use", chosenEndpoint.HostId))
}
h = chosenEndpoint.Address
p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10)
case err != nil:
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoints host address"))
}
h = chosenEndpoint.Address
}
if t.GetAddress() != "" && hostId == "" && hostSetId == "" {
h = t.GetAddress()
p = strconv.FormatUint(uint64(t.GetDefaultPort()), 10)
if h == "" {
return nil, handlers.ApiErrorWithCodeAndMessage(
codes.FailedPrecondition,
"No host was discovered after checking target address and host sources.")
}
// Ensure we don't have a port from the address, which would be unexpected
// FIXME: We've decided to hold off on making this an error until 0.14. In
// the meantime, ignore any port coming from the host address.
hostWithoutPort, _, err := net.SplitHostPort(h)
switch {
case err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr):
// This is what we expect
case err != nil:
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoint host address"))
case err == nil:
h = hostWithoutPort
// Use below logic for 0.14:
/*
return nil, handlers.ApiErrorWithCodeAndMessage(
codes.FailedPrecondition,
"Address specified for use unexpectedly contains a port.")
*/
}
// Generate the endpoint URL
@ -1566,7 +1584,7 @@ func validateCreateRequest(req *pbs.CreateTargetRequest) error {
switch {
case err == nil:
badFields[globals.AddressField] = "Address does not support a port."
case strings.Contains(err.Error(), missingPortErrStr):
case strings.Contains(err.Error(), globals.MissingPortErrStr):
default:
badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err)
}
@ -1642,7 +1660,7 @@ func validateUpdateRequest(req *pbs.UpdateTargetRequest) error {
switch {
case err == nil:
badFields[globals.AddressField] = "Address does not support a port."
case strings.Contains(err.Error(), missingPortErrStr):
case strings.Contains(err.Error(), globals.MissingPortErrStr):
default:
badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err)
}

@ -2494,6 +2494,7 @@ func TestAuthorizeSession(t *testing.T) {
hWithPort := static.TestHosts(t, conn, hcWithPort.GetPublicId(), 1)[0]
shsWithPort := static.TestSets(t, conn, hcWithPort.GetPublicId(), 1)[0]
_ = static.TestSetMembers(t, conn, shsWithPort.GetPublicId(), []*static.Host{hWithPort})
hWithPortBareAddress := hWithPort.GetAddress()
hWithPort.Address = fmt.Sprintf("%s:54321", hWithPort.GetAddress())
hWithPort, _, err = staticRepo.UpdateHost(ctx, hcWithPort.GetProjectId(), hWithPort, hWithPort.GetVersion(), []string{"address"})
require.NoError(t, err)
@ -2551,7 +2552,7 @@ func TestAuthorizeSession(t *testing.T) {
hostSourceId: shsWithPort.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: hWithPort.GetPublicId(),
wantedEndpoint: hWithPort.GetAddress(),
wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort),
},
{
name: "plugin host",
@ -3473,9 +3474,8 @@ func TestAuthorizeSession_Errors(t *testing.T) {
setup: []func(tcpTarget target.Target) uint32{workerExists, hostExists, libraryExists},
},
{
name: "no port",
name: "no host port",
setup: []func(tcpTarget target.Target) uint32{workerExists, hostWithoutPort, libraryExists},
err: true,
},
{
name: "no hosts",
@ -3495,7 +3495,7 @@ func TestAuthorizeSession_Errors(t *testing.T) {
}
for i, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i))
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), fmt.Sprintf("test-%d", i), target.WithDefaultPort(22))
for _, fn := range tc.setup {
ver := fn(tar)

@ -774,7 +774,7 @@ func (s Service) toProto(ctx context.Context, in *server.Worker, opt ...handlers
if outputFields.Has(globals.ScopeIdField) {
out.ScopeId = in.GetScopeId()
}
if outputFields.Has(globals.DirectlyConnectedDownstreamWorkers) {
if outputFields.Has(globals.DirectlyConnectedDownstreamWorkersField) {
out.DirectlyConnectedDownstreamWorkers = downstreamWorkers(ctx, in.GetPublicId(), s.downstreams)
}
if outputFields.Has(globals.DescriptionField) && in.GetDescription() != "" {

@ -49,7 +49,7 @@ func (w *Worker) StartControllerConnections() error {
initialAddrs = append(initialAddrs, addr)
default:
host, port, err := net.SplitHostPort(addr)
if err != nil && strings.Contains(err.Error(), "missing port in address") {
if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) {
host, port, err = net.SplitHostPort(net.JoinHostPort(addr, "9201"))
}
if err != nil {

Loading…
Cancel
Save