From a65d5d8573a49d696d9f91bb123a869da24f0ab2 Mon Sep 17 00:00:00 2001 From: Hugo <10965479+hugoghx@users.noreply.github.com> Date: Tue, 13 May 2025 16:55:28 +0100 Subject: [PATCH] feat: Normalize fields across various Boundary components (#5599) Normalizes various IP/Address/Host fields across Boundary resources/components to comply with IPv6 specifications. --- globals/errors.go | 6 - internal/auth/ldap/auth_method.go | 7 +- internal/auth/ldap/auth_method_test.go | 15 +- internal/auth/ldap/testing.go | 4 + internal/cmd/base/dev.go | 6 +- internal/cmd/base/listener.go | 10 +- internal/cmd/base/server_test.go | 55 +++- internal/cmd/base/servers.go | 21 +- internal/cmd/commands/connect/connect.go | 4 +- internal/cmd/commands/dev/dev.go | 2 +- internal/cmd/commands/server/server.go | 4 +- internal/cmd/config/config.go | 45 ++- internal/cmd/config/config_test.go | 138 +++++---- internal/credential/vault/testing.go | 2 +- .../handlers/accounts/account_service.go | 13 +- .../handlers/accounts/account_service_test.go | 54 ++++ .../authmethods/authmethod_service.go | 8 +- .../controller/handlers/authmethods/oidc.go | 9 + .../handlers/authmethods/oidc_test.go | 70 ++++- .../credentialstore_service.go | 8 + .../credentialstore_service_test.go | 59 ++++ .../controller/handlers/hosts/host_service.go | 25 +- .../handlers/hosts/host_service_test.go | 65 +++++ .../handlers/targets/target_service.go | 26 +- .../targets/tcp/target_service_test.go | 271 +++++++++++++++++- .../daemon/worker/controller_connection.go | 2 +- internal/daemon/worker/handler.go | 2 +- internal/host/plugin/job_set_sync_test.go | 5 +- internal/host/static/repository_host_test.go | 66 +---- internal/ratelimit/handler_test.go | 32 +-- internal/session/session.go | 2 +- internal/session/session_connect_with_test.go | 6 +- .../target/tcp/repository_tcp_target_test.go | 42 ++- internal/tests/api/targets/target_test.go | 18 +- .../cluster/parallel/unix_listener_test.go | 16 ++ internal/util/net.go | 123 ++++++-- internal/util/net_test.go | 63 ++-- 37 files changed, 992 insertions(+), 312 deletions(-) delete mode 100644 globals/errors.go diff --git a/globals/errors.go b/globals/errors.go deleted file mode 100644 index bbab51d04a..0000000000 --- a/globals/errors.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package globals - -const MissingPortErrStr = "missing port in address" diff --git a/internal/auth/ldap/auth_method.go b/internal/auth/ldap/auth_method.go index 4ae52c2f34..b81988920f 100644 --- a/internal/auth/ldap/auth_method.go +++ b/internal/auth/ldap/auth_method.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/go-secure-stdlib/parseutil" "google.golang.org/protobuf/proto" ) @@ -209,7 +210,11 @@ func (am *AuthMethod) convertUrls(ctx context.Context) ([]*Url, error) { } newValObjs := make([]*Url, 0, len(am.Urls)) for priority, u := range am.Urls { - parsed, err := url.Parse(u) + addr, err := parseutil.NormalizeAddr(u) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + parsed, err := url.Parse(addr) if err != nil { return nil, errors.Wrap(ctx, err, op) } diff --git a/internal/auth/ldap/auth_method_test.go b/internal/auth/ldap/auth_method_test.go index 7fd6d36102..16b9b1e973 100644 --- a/internal/auth/ldap/auth_method_test.go +++ b/internal/auth/ldap/auth_method_test.go @@ -343,7 +343,7 @@ func TestAuthMethod_oplog(t *testing.T) { func Test_convertValueObjects(t *testing.T) { testCtx := context.TODO() testPublicId := "test-id" - testLdapServers := []string{"ldaps://ldap1.alice.com", "ldaps://ldap2.alice.com"} + testLdapServers := []string{"ldaps://ldap1.alice.com", "ldaps://ldap2.alice.com", "ldap://[2001:BEEF:0:0:0:1:0:0001]:80"} _, pem := TestGenerateCA(t, "localhost") testCerts := []string{pem} c, err := NewCertificate(testCtx, testPublicId, pem) @@ -499,7 +499,18 @@ func Test_convertValueObjects(t *testing.T) { }, }, wantErrMatch: errors.T(errors.Unknown), - wantErrContains: "first path segment in URL cannot contain colon", + wantErrContains: "failed to parse address", + }, + { + name: "invalid-url-has-invalid-ipv6", + am: &AuthMethod{ + AuthMethod: &store.AuthMethod{ + PublicId: testPublicId, + Urls: []string{"ldaps://[2001:BEEF:0:0:1:0:0001]"}, + }, + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "host contains an invalid IPv6 literal", }, { name: "invalid-client-cert", diff --git a/internal/auth/ldap/testing.go b/internal/auth/ldap/testing.go index c33bf5f7d8..eeb9f059d9 100644 --- a/internal/auth/ldap/testing.go +++ b/internal/auth/ldap/testing.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/boundary/internal/db" wrapping "github.com/hashicorp/go-kms-wrapping/v2" + "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/stretchr/testify/require" ) @@ -255,6 +256,9 @@ func TestConvertToUrls(t testing.TB, urls ...string) []*url.URL { require.NotEmpty(urls) var convertedUrls []*url.URL for _, u := range urls { + var err error + u, err = parseutil.NormalizeAddr(u) + require.NoError(err) parsed, err := url.Parse(u) require.NoError(err) require.Contains([]string{"ldap", "ldaps"}, parsed.Scheme) diff --git a/internal/cmd/base/dev.go b/internal/cmd/base/dev.go index e8881d5d3e..d4f4e2435e 100644 --- a/internal/cmd/base/dev.go +++ b/internal/cmd/base/dev.go @@ -244,7 +244,7 @@ func (b *Server) CreateDevLdapAuthMethod(ctx context.Context) error { continue } host, _, err = util.SplitHostPort(ln.Config.Address) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return fmt.Errorf("error splitting host/port: %w", err) } } @@ -262,7 +262,7 @@ func (b *Server) CreateDevLdapAuthMethod(ctx context.Context) error { // added back, otherwise the gldap server will fail to start due to a parsing // error. if ip := net.ParseIP(host); ip != nil { - if ip.To16() != nil { + if ip.To4() == nil && ip.To16() != nil { host = fmt.Sprintf("[%s]", host) } } @@ -463,7 +463,7 @@ func (b *Server) CreateDevOidcAuthMethod(ctx context.Context) error { continue } b.DevOidcSetup.hostAddr, b.DevOidcSetup.callbackPort, err = util.SplitHostPort(ln.Config.Address) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return fmt.Errorf("error splitting host/port: %w", err) } if b.DevOidcSetup.callbackPort == "" { diff --git a/internal/cmd/base/listener.go b/internal/cmd/base/listener.go index bfaad009de..793cd04060 100644 --- a/internal/cmd/base/listener.go +++ b/internal/cmd/base/listener.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/go-secure-stdlib/listenerutil" + "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/reloadutil" "github.com/mitchellh/cli" "github.com/pires/go-proxyproto" @@ -139,7 +140,7 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U } host, port, err := util.SplitHostPort(l.Address) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return "", nil, fmt.Errorf("error splitting host/port: %w", err) } if port == "" { @@ -173,10 +174,15 @@ func tcpListenerFactory(purpose string, l *listenerutil.ListenerConfig, ui cli.U } if l.RandomPort { - port = "" + port = "0" // net.Listen will choose an available port automatically. Used for tests. } finalListenAddr := net.JoinHostPort(host, port) + normalizedListenAddr, err := parseutil.NormalizeAddr(finalListenAddr) + if err != nil { + return "", nil, fmt.Errorf("failed to normalize final listen addr %q: %w", finalListenAddr, err) + } + finalListenAddr = normalizedListenAddr ln, err := net.Listen(bindProto, finalListenAddr) if err != nil { diff --git a/internal/cmd/base/server_test.go b/internal/cmd/base/server_test.go index 3fbb20b0ef..94f6fbadcf 100644 --- a/internal/cmd/base/server_test.go +++ b/internal/cmd/base/server_test.go @@ -525,6 +525,35 @@ func TestSetupWorkerPublicAddress(t *testing.T) { expErrStr: "", expPublicAddress: "127.0.0.1:8080", }, + { + name: "setting public address directly with invalid ipv6", + inputConfig: &config.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*listenerutil.ListenerConfig{}, + }, + Worker: &config.Worker{ + PublicAddr: "[2001:4860:4860:0:0:0:8888]", + }, + }, + inputFlagValue: "", + expErr: true, + expErrStr: "Error normalizing worker address", + }, + { + name: "setting public address directly with ipv6 but no brackets", + inputConfig: &config.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*listenerutil.ListenerConfig{}, + }, + Worker: &config.Worker{ + PublicAddr: "2001:4860:4860:0:0:0:0:8888", + }, + }, + inputFlagValue: "", + expErr: false, + expErrStr: "", + expPublicAddress: "[2001:4860:4860::8888]:9202", + }, { name: "setting public address directly with ipv6", inputConfig: &config.Config{ @@ -532,13 +561,13 @@ func TestSetupWorkerPublicAddress(t *testing.T) { Listeners: []*listenerutil.ListenerConfig{}, }, Worker: &config.Worker{ - PublicAddr: "[2001:4860:4860:0:0:0:0:8888]", + PublicAddr: "2001:4860:4860:0:0:0:0:8888", }, }, inputFlagValue: "", expErr: false, expErrStr: "", - expPublicAddress: "[2001:4860:4860:0:0:0:0:8888]:9202", + expPublicAddress: "[2001:4860:4860::8888]:9202", }, { name: "setting public address directly with ipv6:port", @@ -553,7 +582,7 @@ func TestSetupWorkerPublicAddress(t *testing.T) { inputFlagValue: "", expErr: false, expErrStr: "", - expPublicAddress: "[2001:4860:4860:0:0:0:0:8888]:8080", + expPublicAddress: "[2001:4860:4860::8888]:8080", }, { name: "setting public address directly with abbreviated ipv6", @@ -562,7 +591,7 @@ func TestSetupWorkerPublicAddress(t *testing.T) { Listeners: []*listenerutil.ListenerConfig{}, }, Worker: &config.Worker{ - PublicAddr: "[2001:4860:4860::8888]", + PublicAddr: "2001:4860:4860::8888", }, }, inputFlagValue: "", @@ -781,6 +810,20 @@ func TestSetupWorkerPublicAddress(t *testing.T) { expErrStr: "", expPublicAddress: ":9202", }, + { + name: "read unix address from listeners ip only", + inputConfig: &config.Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*listenerutil.ListenerConfig{ + {Purpose: []string{"proxy"}, Address: "someaddr", Type: "unix"}, + }, + }, + Worker: &config.Worker{}, + }, + expErr: false, + expErrStr: "", + expPublicAddress: "someaddr:9202", + }, { name: "using flag value to point to nonexistent file", inputConfig: &config.Config{ @@ -802,9 +845,9 @@ func TestSetupWorkerPublicAddress(t *testing.T) { }, Worker: &config.Worker{}, }, - inputFlagValue: "abc::123", + inputFlagValue: "abc::123:::", expErr: true, - expErrStr: "Error splitting public adddress host/port: address abc::123: too many colons in address", + expErrStr: "Error splitting public adddress host/port: too many colons in address", expPublicAddress: "", }, { diff --git a/internal/cmd/base/servers.go b/internal/cmd/base/servers.go index 351f8a6f81..39cd080d4c 100644 --- a/internal/cmd/base/servers.go +++ b/internal/cmd/base/servers.go @@ -812,12 +812,16 @@ func (b *Server) SetupWorkerPublicAddress(conf *config.Config, flagValue string) if flagValue != "" { conf.Worker.PublicAddr = flagValue } + isUnixListener := false if conf.Worker.PublicAddr == "" { FindAddr: for _, listener := range conf.Listeners { for _, purpose := range listener.Purpose { if purpose == "proxy" { conf.Worker.PublicAddr = listener.Address + if strings.EqualFold(listener.Type, "unix") { + isUnixListener = true + } break FindAddr } } @@ -836,14 +840,29 @@ func (b *Server) SetupWorkerPublicAddress(conf *config.Config, flagValue string) } host, port, err := util.SplitHostPort(conf.Worker.PublicAddr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return fmt.Errorf("Error splitting public adddress host/port: %w", err) } + if host != "" { + if host, err = parseutil.NormalizeAddr(host); err != nil { + return fmt.Errorf("Error normalizing worker address") + } + } if port == "" { port = "9202" } conf.Worker.PublicAddr = util.JoinHostPort(host, port) + if host != "" && !isUnixListener { + // NormalizeAddr requires that a host be present, but that is not + // guaranteed in this code path. Additionally, if no host is present, + // there's no need to normalize. + conf.Worker.PublicAddr, err = parseutil.NormalizeAddr(conf.Worker.PublicAddr) + if err != nil { + return fmt.Errorf("Failed to normalize worker public adddress: %w", err) + } + } + return nil } diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index 986e0c6886..d5e569ac4f 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -475,7 +475,7 @@ func (c *Command) Run(args []string) (retCode int) { proxyAddr := clientProxy.ListenerAddress(context.Background()) var clientProxyHost, clientProxyPort string clientProxyHost, clientProxyPort, err = util.SplitHostPort(proxyAddr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { c.PrintCliError(fmt.Errorf("error splitting listener addr: %w", err)) return base.CommandCliError } @@ -600,7 +600,7 @@ func (c *Command) handleExec(clientProxy *apiproxy.ClientProxy, passthroughArgs var host, port string var err error host, port, err = util.SplitHostPort(addr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { c.PrintCliError(fmt.Errorf("Error splitting listener addr: %w", err)) c.execCmdReturnValue.Store(int32(3)) return diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 6edf736970..3596c2c888 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -593,7 +593,7 @@ func (c *Command) Run(args []string) int { } host, port, err := util.SplitHostPort(c.flagHostAddress) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { c.UI.Error(fmt.Errorf("Invalid host address specified: %w", err).Error()) return base.CommandUserError } diff --git a/internal/cmd/commands/server/server.go b/internal/cmd/commands/server/server.go index e27e671de8..6134de6d66 100644 --- a/internal/cmd/commands/server/server.go +++ b/internal/cmd/commands/server/server.go @@ -358,7 +358,7 @@ func (c *Command) Run(args []string) int { } for _, upstream := range c.Config.Worker.InitialUpstreams { host, _, err := util.SplitHostPort(upstream) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { c.UI.Error(fmt.Errorf("Invalid worker upstream address %q: %w", upstream, err).Error()) return base.CommandUserError } @@ -412,7 +412,7 @@ func (c *Command) Run(args []string) int { continue } host, _, err := util.SplitHostPort(ln.Address) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { c.UI.Error(fmt.Errorf("Invalid cluster listener address %q: %w", ln.Address, err).Error()) return base.CommandUserError } diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 388f045182..c66529b875 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -139,7 +139,7 @@ kms "aead" { } listener "tcp" { - address = "[::1]" + address = "::1" purpose = "api" tls_disable = true cors_enabled = true @@ -147,12 +147,12 @@ listener "tcp" { } listener "tcp" { - address = "[::1]" + address = "::1" purpose = "cluster" } listener "tcp" { - address = "[::1]" + address = "::1" purpose = "ops" tls_disable = true } @@ -160,15 +160,15 @@ listener "tcp" { devIpv6WorkerExtraConfig = ` listener "tcp" { - address = "[::1]" + address = "::1" purpose = "proxy" } worker { name = "w_1234567890" description = "A default worker created in dev mode" - public_addr = "[::1]" - initial_upstreams = ["[::1]"] + public_addr = "::1" + initial_upstreams = ["::1"] tags { type = ["dev", "local"] } @@ -1240,14 +1240,13 @@ func parseWorkerUpstreams(c *Config) ([]string, error) { return nil, nil } + upstreams := make([]string, 0) switch t := c.Worker.InitialUpstreamsRaw.(type) { case []any: - var upstreams []string err := mapstructure.WeakDecode(c.Worker.InitialUpstreamsRaw, &upstreams) if err != nil { return nil, fmt.Errorf("failed to decode worker initial_upstreams block into config field: %w", err) } - return upstreams, nil case string: upstreamsStr, err := parseutil.ParsePath(t) @@ -1255,17 +1254,25 @@ func parseWorkerUpstreams(c *Config) ([]string, error) { return nil, fmt.Errorf("bad env var or file pointer: %w", err) } - var upstreams []string err = json.Unmarshal([]byte(upstreamsStr), &upstreams) if err != nil { return nil, fmt.Errorf("failed to unmarshal env/file contents: %w", err) } - return upstreams, nil default: typ := reflect.TypeOf(t) return nil, fmt.Errorf("unexpected type %q", typ.String()) } + + for i := range upstreams { + normalized, err := parseutil.NormalizeAddr(upstreams[i]) + if err != nil { + return nil, fmt.Errorf("failed to normalize worker upstream %q: %w", upstreams[i], err) + } + upstreams[i] = normalized + } + + return upstreams, nil } func parseEventing(eventObj *ast.ObjectItem) (*event.EventerConfig, error) { @@ -1379,12 +1386,16 @@ func (c *Config) SetupControllerPublicClusterAddress(flagValue string) error { if flagValue != "" { c.Controller.PublicClusterAddr = flagValue } + isUnixListener := false if c.Controller.PublicClusterAddr == "" { FindAddr: for _, listener := range c.Listeners { for _, purpose := range listener.Purpose { if purpose == "cluster" { c.Controller.PublicClusterAddr = listener.Address + if strings.EqualFold(listener.Type, "unix") { + isUnixListener = true + } break FindAddr } } @@ -1403,7 +1414,7 @@ func (c *Config) SetupControllerPublicClusterAddress(flagValue string) error { } host, port, err := util.SplitHostPort(c.Controller.PublicClusterAddr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return fmt.Errorf("Error splitting public cluster adddress host/port: %w", err) } if port == "" { @@ -1411,6 +1422,16 @@ func (c *Config) SetupControllerPublicClusterAddress(flagValue string) error { } c.Controller.PublicClusterAddr = util.JoinHostPort(host, port) + if host != "" && !isUnixListener { + // NormalizeAddr requires that a host be present, but that is not + // guaranteed in this code path. Additionally, if no host is present, + // there's no need to normalize. + c.Controller.PublicClusterAddr, err = parseutil.NormalizeAddr(c.Controller.PublicClusterAddr) + if err != nil { + return fmt.Errorf("Failed to normalize controller public cluster adddress: %w", err) + } + } + return nil } @@ -1464,7 +1485,7 @@ func (c *Config) SetupWorkerInitialUpstreams() error { } // Best effort see if it's a domain name and if not assume it must match host, _, err := util.SplitHostPort(c.Worker.InitialUpstreams[0]) - if err == nil { + if err == nil || errors.Is(err, util.ErrMissingPort) { ip := net.ParseIP(host) if ip == nil { // Assume it's a domain name diff --git a/internal/cmd/config/config_test.go b/internal/cmd/config/config_test.go index c5b01f5e1d..6de0c92354 100644 --- a/internal/cmd/config/config_test.go +++ b/internal/cmd/config/config_test.go @@ -781,98 +781,112 @@ func TestDevWorkerRecordingStoragePath(t *testing.T) { } } +// TestDevControllerIpv6 validates that all listeners use an IPv6 address when +// the WithIPv6Enabled(true) option is passed into DevController. Other dev +// controller configurations are validated in TestDevController. func TestDevControllerIpv6(t *testing.T) { - require, assert := require.New(t), assert.New(t) - // This test only validates that all listeners are utilizing an IPv6 address. - // Other dev controller configurations are validates in TestDevController. + require := require.New(t) + actual, err := DevController(WithIPv6Enabled(true)) require.NoError(err) - // expected an error here because we purposely did not provide a port number - // to allow randomly assigned port values + // Expected an error here because PublicClusterAddr is not set. _, _, err = net.SplitHostPort(actual.Controller.PublicClusterAddr) require.Error(err) - // assert the square brackets are removed from the host ipv6 address and that the port value is empty + // Same here. publicAddr, port, err := util.SplitHostPort(actual.Controller.PublicClusterAddr) - require.NoError(err) - assert.Empty(port) - assert.Empty(publicAddr) + require.ErrorIs(err, util.ErrMissingPort) + require.Empty(port) + require.Empty(publicAddr) require.NotEmpty(actual.Listeners) for _, l := range actual.Listeners { addr, _, err := util.SplitHostPort(l.Address) - require.NoError(err) + require.ErrorIs(err, util.ErrMissingPort) + require.NotEmpty(t, addr) + ip := net.ParseIP(addr) - assert.NotNil(ip, "failed to parse listener address for %v", l.Purpose) - assert.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) + require.NotNil(ip, "failed to parse listener address for %v", l.Purpose) + require.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) } } +// TestDevWorkerIpv6 validates that all listeners use an IPv6 address when the +// WithIPv6Enabled(true) option is passed into DevWorker. Other dev worker +// configurations are validated in TestDevWorker. func TestDevWorkerIpv6(t *testing.T) { - require, assert := require.New(t), assert.New(t) - // This test only validates that all listeners are utilizing an IPv6 address. - // Other dev worker configurations are validates in TestDevWorker. + require := require.New(t) + actual, err := DevWorker(WithIPv6Enabled(true)) require.NoError(err) - // expected an error here because we purposely did not provide a port number - // to allow randomly assigned port values + // Expected an error here because PublicAddr does not have a port. _, _, err = net.SplitHostPort(actual.Worker.PublicAddr) require.Error(err) - // assert the square brackets are removed from the worker ipv6 address and that the port value is empty + // util.SplitHostPort, however, can handle it when ports are missing. publicAddr, port, err := util.SplitHostPort(actual.Worker.PublicAddr) - require.NoError(err) - assert.Empty(port) + require.ErrorIs(err, util.ErrMissingPort) + require.Empty(port) + require.NotEmpty(t, publicAddr) + ip := net.ParseIP(publicAddr) - assert.NotNil(ip, "failed to parse worker public address") - assert.NotNil(ip.To16(), "worker public address is not IPv6 %s", actual.Worker.PublicAddr) + require.NotNil(ip, "failed to parse worker public address") + require.NotNil(ip.To16(), "worker public address is not IPv6 %s", actual.Worker.PublicAddr) require.NotEmpty(actual.Listeners) for _, l := range actual.Listeners { addr, _, err := util.SplitHostPort(l.Address) - require.NoError(err) + require.ErrorIs(err, util.ErrMissingPort) + require.NotEmpty(addr) + ip := net.ParseIP(addr) - assert.NotNil(ip, "failed to parse listener address for %v", l.Purpose) - assert.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) + require.NotNil(ip, "failed to parse listener address for %v", l.Purpose) + require.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) } } +// TestDevCombinedIpv6 validates that all listeners use an IPv6 address when the +// WithIPv6Enabled(true) option is passed into DevCombined. func TestDevCombinedIpv6(t *testing.T) { - require, assert := require.New(t), assert.New(t) - // This test only validates that all listeners are utilizing an IPv6 address. + require := require.New(t) + actual, err := DevCombined(WithIPv6Enabled(true)) require.NoError(err) - // expected an error here because we purposely did not provide a port number - // to allow randomly assigned port values for the worker and controller + // Expected to fail because PublicAddr does not have a port. _, _, err = net.SplitHostPort(actual.Worker.PublicAddr) require.Error(err) + // Expected to fail because PublicClusterAddr is not set. _, _, err = net.SplitHostPort(actual.Controller.PublicClusterAddr) require.Error(err) - // assert the square brackets are removed from the host ipv6 address and that the port value is empty + // util.SplitHostPort, however, can handle it when ports are missing. publicAddr, port, err := util.SplitHostPort(actual.Worker.PublicAddr) - require.NoError(err) - assert.Empty(port) + require.ErrorIs(err, util.ErrMissingPort) + require.Empty(port) + require.NotEmpty(publicAddr) + ip := net.ParseIP(publicAddr) - assert.NotNil(ip, "failed to parse worker public address") - assert.NotNil(ip.To16(), "worker public address is not IPv6 %s", actual.Worker.PublicAddr) + require.NotNil(ip, "failed to parse worker public address") + require.NotNil(ip.To16(), "worker public address is not IPv6 %s", actual.Worker.PublicAddr) - // assert the square brackets are removed from the controller ipv6 address and that the port value is empty + // Expected to fail because PublicClusterAddr is not set. publicAddr, port, err = util.SplitHostPort(actual.Controller.PublicClusterAddr) - require.NoError(err) - assert.Empty(port) - assert.Empty(publicAddr) + require.ErrorIs(err, util.ErrMissingPort) + require.Empty(port) + require.Empty(publicAddr) require.NotEmpty(actual.Listeners) for _, l := range actual.Listeners { addr, _, err := util.SplitHostPort(l.Address) - require.NoError(err) + require.ErrorIs(err, util.ErrMissingPort) + require.NotEmpty(addr) + ip := net.ParseIP(addr) - assert.NotNil(ip, "failed to parse listener address for %v", l.Purpose) - assert.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) + require.NotNil(ip, "failed to parse listener address for %v", l.Purpose) + require.NotNil(ip.To16(), "failed to convert address to IPv6 for %v, found %v", l.Purpose, addr) } } @@ -1607,10 +1621,10 @@ func TestWorkerUpstreams(t *testing.T) { in: ` worker { name = "test" - initial_upstreams = ["[2001:4860:4860:0:0:0:0:8888]"] + initial_upstreams = ["2001:4860:4860:0:0:0:0:8888"] } `, - expWorkerUpstreams: []string{"[2001:4860:4860:0:0:0:0:8888]"}, + expWorkerUpstreams: []string{"2001:4860:4860::8888"}, expErr: false, }, { @@ -1618,10 +1632,10 @@ func TestWorkerUpstreams(t *testing.T) { in: ` worker { name = "test" - initial_upstreams = ["[2001:4860:4860::8888]"] + initial_upstreams = ["2001:4860:4860::8888"] } `, - expWorkerUpstreams: []string{"[2001:4860:4860::8888]"}, + expWorkerUpstreams: []string{"2001:4860:4860::8888"}, expErr: false, }, { @@ -2467,13 +2481,13 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { Listeners: []*listenerutil.ListenerConfig{}, }, Controller: &Controller{ - PublicClusterAddr: "[2001:4860:4860:0:0:0:0:8888]", + PublicClusterAddr: "2001:4860:4860:0:0:0:0:8888", }, }, inputFlagValue: "", expErr: false, expErrStr: "", - expPublicClusterAddress: "[2001:4860:4860:0:0:0:0:8888]:9201", + expPublicClusterAddress: "[2001:4860:4860::8888]:9201", }, { name: "setting public cluster address directly with ipv6:port", @@ -2488,7 +2502,7 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { inputFlagValue: "", expErr: false, expErrStr: "", - expPublicClusterAddress: "[2001:4860:4860:0:0:0:0:8888]:8080", + expPublicClusterAddress: "[2001:4860:4860::8888]:8080", }, { name: "setting public cluster address directly with abbreviated ipv6", @@ -2497,7 +2511,7 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { Listeners: []*listenerutil.ListenerConfig{}, }, Controller: &Controller{ - PublicClusterAddr: "[2001:4860:4860::8888]", + PublicClusterAddr: "2001:4860:4860::8888", }, }, inputFlagValue: "", @@ -2707,35 +2721,35 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { inputConfig: &Config{ SharedConfig: &configutil.SharedConfig{ Listeners: []*listenerutil.ListenerConfig{ - {Purpose: []string{"cluster"}, Address: "[2001:4860:4860:0:0:0:0:8888]"}, + {Purpose: []string{"cluster"}, Address: "2001:4860:4860:0:0:0:0:8888"}, }, }, Controller: &Controller{}, }, expErr: false, expErrStr: "", - expPublicClusterAddress: "[2001:4860:4860:0:0:0:0:8888]:9201", + expPublicClusterAddress: "[2001:4860:4860::8888]:9201", }, { name: "read address from listeners ipv6:port", inputConfig: &Config{ SharedConfig: &configutil.SharedConfig{ Listeners: []*listenerutil.ListenerConfig{ - {Purpose: []string{"cluster"}, Address: "[2001:4860:4860:0:0:0:0:8888]:8080"}, + {Purpose: []string{"cluster"}, Address: "[2001:4860:4860::8888]:8080"}, }, }, Controller: &Controller{}, }, expErr: false, expErrStr: "", - expPublicClusterAddress: "[2001:4860:4860:0:0:0:0:8888]:8080", + expPublicClusterAddress: "[2001:4860:4860::8888]:8080", }, { name: "read address from listeners abbreviated ipv6 only", inputConfig: &Config{ SharedConfig: &configutil.SharedConfig{ Listeners: []*listenerutil.ListenerConfig{ - {Purpose: []string{"cluster"}, Address: "[2001:4860:4860::8888]"}, + {Purpose: []string{"cluster"}, Address: "2001:4860:4860::8888"}, }, }, Controller: &Controller{}, @@ -2793,9 +2807,9 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { }, Controller: &Controller{}, }, - inputFlagValue: "abc::123", + inputFlagValue: "abc::123:::", expErr: true, - expErrStr: "Error splitting public cluster adddress host/port: address abc::123: too many colons in address", + expErrStr: "Error splitting public cluster adddress host/port: too many colons in address", expPublicClusterAddress: "", }, { @@ -2811,6 +2825,18 @@ func TestSetupControllerPublicClusterAddress(t *testing.T) { expErrStr: "Error parsing IP template on controller public cluster addr: unable to parse address template \"{{ somethingthatdoesntexist }}\": unable to parse template \"{{ somethingthatdoesntexist }}\": template: sockaddr.Parse:1: function \"somethingthatdoesntexist\" not defined", expPublicClusterAddress: "", }, + { + name: "unix listener", + inputConfig: &Config{ + SharedConfig: &configutil.SharedConfig{ + Listeners: []*listenerutil.ListenerConfig{ + {Address: "someaddr", Type: "unix", Purpose: []string{"cluster"}}, + }, + }, + Controller: &Controller{}, + }, + expPublicClusterAddress: "someaddr:9201", + }, } for _, tt := range tests { diff --git a/internal/credential/vault/testing.go b/internal/credential/vault/testing.go index fdd4641e69..c72c217065 100644 --- a/internal/credential/vault/testing.go +++ b/internal/credential/vault/testing.go @@ -548,7 +548,7 @@ func getDefaultTestOptions(t testing.TB) testOptions { vaultTLS: TestNoTLS, dockerNetwork: false, tokenPeriod: defaultPeriod, - serverCertHostNames: []string{"localhost"}, + serverCertHostNames: []string{"localhost", "127.0.0.1", "::1"}, } } diff --git a/internal/daemon/controller/handlers/accounts/account_service.go b/internal/daemon/controller/handlers/accounts/account_service.go index 4b913d3458..af3b493f00 100644 --- a/internal/daemon/controller/handlers/accounts/account_service.go +++ b/internal/daemon/controller/handlers/accounts/account_service.go @@ -30,6 +30,7 @@ import ( "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/subtypes" pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/accounts" + "github.com/hashicorp/go-secure-stdlib/parseutil" "golang.org/x/exp/maps" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/structpb" @@ -691,7 +692,11 @@ func (s Service) createOidcInRepo(ctx context.Context, am auth.AuthMethod, item } attrs := item.GetOidcAccountAttributes() if attrs.GetIssuer() != "" { - u, err := url.Parse(attrs.GetIssuer()) + niss, err := parseutil.NormalizeAddr(attrs.GetIssuer()) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to normalize issuer"), errors.WithCode(errors.InvalidParameter)) + } + u, err := url.Parse(niss) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to parse issuer"), errors.WithCode(errors.InvalidParameter)) } @@ -1321,8 +1326,10 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateAccountRequest) e if err != nil { badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err) } - if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" { - badFields[issuerField] = "The path segment of the url should be empty." + if du != nil { + if trimmed := strings.TrimSuffix(strings.TrimSuffix(du.RawPath, "/"), "/.well-known/openid-configuration"); trimmed != "" { + badFields[issuerField] = "The path segment of the url should be empty." + } } } if attrs.GetFullName() != "" { diff --git a/internal/daemon/controller/handlers/accounts/account_service_test.go b/internal/daemon/controller/handlers/accounts/account_service_test.go index 246fc057dd..8688664569 100644 --- a/internal/daemon/controller/handlers/accounts/account_service_test.go +++ b/internal/daemon/controller/handlers/accounts/account_service_test.go @@ -2424,6 +2424,41 @@ func TestCreateOidc(t *testing.T) { }, }, }, + { + name: "Create a valid Account with IPv6 issuer address", + req: &pbs.CreateAccountRequest{ + Item: &pb.Account{ + AuthMethodId: am.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "name-ipv6-iss"}, + Description: &wrapperspb.StringValue{Value: "desc-ipv6-iss"}, + Type: oidc.Subtype.String(), + Attrs: &pb.Account_OidcAccountAttributes{ + OidcAccountAttributes: &pb.OidcAccountAttributes{ + Issuer: "https://[2001:BEEF:0000:0000:0000:0000:0000:0001]:44344/v1/myissuer", + Subject: "valid-account-ipv6-iss", + }, + }, + }, + }, + res: &pbs.CreateAccountResponse{ + Uri: fmt.Sprintf("accounts/%s_", globals.OidcAccountPrefix), + Item: &pb.Account{ + AuthMethodId: am.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "name-ipv6-iss"}, + Description: &wrapperspb.StringValue{Value: "desc-ipv6-iss"}, + Scope: &scopepb.ScopeInfo{Id: o.GetPublicId(), Type: scope.Org.String(), ParentScopeId: scope.Global.String()}, + Version: 1, + Type: oidc.Subtype.String(), + Attrs: &pb.Account_OidcAccountAttributes{ + OidcAccountAttributes: &pb.OidcAccountAttributes{ + Subject: "valid-account-ipv6-iss", + Issuer: "https://[2001:beef::1]:44344/v1/myissuer", + }, + }, + AuthorizedActions: oidcAuthorizedActions, + }, + }, + }, { name: "Create a valid Account without type defined", req: &pbs.CreateAccountRequest{ @@ -2564,6 +2599,25 @@ func TestCreateOidc(t *testing.T) { res: nil, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Malformed issuer url", + req: &pbs.CreateAccountRequest{ + Item: &pb.Account{ + AuthMethodId: am.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "name-ipv6-iss"}, + Description: &wrapperspb.StringValue{Value: "desc-ipv6-iss"}, + Type: oidc.Subtype.String(), + Attrs: &pb.Account_OidcAccountAttributes{ + OidcAccountAttributes: &pb.OidcAccountAttributes{ + Issuer: "https://2000:0005::0001]", // missing '[' after https:// + Subject: "valid-account-ipv6-iss", + }, + }, + }, + }, + res: nil, + err: handlers.ApiErrorWithCodeAndMessage(codes.InvalidArgument, `Error: "Error in provided request.", Details: {{name: "attributes.issuer", desc: "Cannot be parsed as a url. parse \"https://2000:0005::0001]\": invalid port \":0001]\" after host"}}`), + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/daemon/controller/handlers/authmethods/authmethod_service.go b/internal/daemon/controller/handlers/authmethods/authmethod_service.go index 32f86e8523..684e255c90 100644 --- a/internal/daemon/controller/handlers/authmethods/authmethod_service.go +++ b/internal/daemon/controller/handlers/authmethods/authmethod_service.go @@ -1061,8 +1061,12 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateAuthMethodRequest if err != nil { badFields[issuerField] = fmt.Sprintf("Cannot be parsed as a url. %v", err) } - if !strutil.StrListContains([]string{"http", "https"}, iss.Scheme) { - badFields[issuerField] = fmt.Sprintf("Must have schema %q or %q specified", "http", "https") + if iss != nil { + if !strutil.StrListContains([]string{"http", "https"}, iss.Scheme) { + badFields[issuerField] = fmt.Sprintf("Must have schema %q or %q specified", "http", "https") + } + } else { + badFields[issuerField] = "Cannot be parsed as a url" } } if attrs.GetDisableDiscoveredConfigValidation() { diff --git a/internal/daemon/controller/handlers/authmethods/oidc.go b/internal/daemon/controller/handlers/authmethods/oidc.go index 919aef62ed..428dfde4a0 100644 --- a/internal/daemon/controller/handlers/authmethods/oidc.go +++ b/internal/daemon/controller/handlers/authmethods/oidc.go @@ -19,6 +19,7 @@ import ( pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" "github.com/hashicorp/boundary/internal/types/action" pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/authmethods" + "github.com/hashicorp/go-secure-stdlib/parseutil" "google.golang.org/grpc/codes" ) @@ -419,6 +420,10 @@ func toStorageOidcAuthMethod(ctx context.Context, scopeId string, in *pb.AuthMet // Strip off everything after and including ".well-known/openid-configuration" // but leave the "/" attached to the end. iss = strings.SplitN(iss, ".well-known/", 2)[0] + iss, err := parseutil.NormalizeAddr(iss) + if err != nil { + return nil, false, false, errors.Wrap(ctx, err, op, errors.WithMsg("cannot normalize issuer"), errors.WithCode(errors.InvalidParameter)) + } issuer, err := url.Parse(iss) if err != nil { return nil, false, false, errors.Wrap(ctx, err, op, errors.WithMsg("cannot parse issuer"), errors.WithCode(errors.InvalidParameter)) @@ -426,6 +431,10 @@ func toStorageOidcAuthMethod(ctx context.Context, scopeId string, in *pb.AuthMet opts = append(opts, oidc.WithIssuer(issuer)) } if apiUrl := strings.TrimSpace(attrs.GetApiUrlPrefix().GetValue()); apiUrl != "" { + apiUrl, err := parseutil.NormalizeAddr(apiUrl) + if err != nil { + return nil, false, false, errors.Wrap(ctx, err, op, errors.WithMsg("cannot normalize api_url_prefix"), errors.WithCode(errors.InvalidParameter)) + } apiU, err := url.Parse(apiUrl) if err != nil { return nil, false, false, errors.Wrap(ctx, err, op, errors.WithMsg("cannot parse api_url_prefix"), errors.WithCode(errors.InvalidParameter)) diff --git a/internal/daemon/controller/handlers/authmethods/oidc_test.go b/internal/daemon/controller/handlers/authmethods/oidc_test.go index fc79db52a1..bf458d4e32 100644 --- a/internal/daemon/controller/handlers/authmethods/oidc_test.go +++ b/internal/daemon/controller/handlers/authmethods/oidc_test.go @@ -454,6 +454,40 @@ func TestUpdate_OIDC(t *testing.T) { }, }, }, + { + name: "Update Issuer IPv6", + req: &pbs.UpdateAuthMethodRequest{ + UpdateMask: &field_mask.FieldMask{ + Paths: []string{"attributes.issuer"}, + }, + Item: &pb.AuthMethod{ + Attrs: func() *pb.AuthMethod_OidcAuthMethodsAttributes { + f := proto.Clone(defaultAttributes.OidcAuthMethodsAttributes).(*pb.OidcAuthMethodAttributes) + f.Issuer = wrapperspb.String("https://[2001:BEEF:0000:0000:0000:0000:0000:0001]:44344/v1/myissuer/.well-known/openid-configuration") + f.DisableDiscoveredConfigValidation = true + return &pb.AuthMethod_OidcAuthMethodsAttributes{OidcAuthMethodsAttributes: f} + }(), + }, + }, + res: &pbs.UpdateAuthMethodResponse{ + Item: &pb.AuthMethod{ + ScopeId: o.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "default"}, + Description: &wrapperspb.StringValue{Value: "default"}, + Type: oidc.Subtype.String(), + Attrs: func() *pb.AuthMethod_OidcAuthMethodsAttributes { + f := proto.Clone(defaultReadAttributes.OidcAuthMethodsAttributes).(*pb.OidcAuthMethodAttributes) + f.Issuer = wrapperspb.String("https://[2001:beef::1]:44344/v1/myissuer/") + f.DisableDiscoveredConfigValidation = true + return &pb.AuthMethod_OidcAuthMethodsAttributes{OidcAuthMethodsAttributes: f} + }(), + + Scope: defaultScopeInfo, + AuthorizedActions: oidcAuthorizedActions, + AuthorizedCollectionActions: authorizedCollectionActions, + }, + }, + }, { name: "invalid-issuer-port", req: &pbs.UpdateAuthMethodRequest{ @@ -855,6 +889,38 @@ func TestUpdate_OIDC(t *testing.T) { }, }, }, + { + name: "Change Api Url Prefix IPv6", + req: &pbs.UpdateAuthMethodRequest{ + UpdateMask: &field_mask.FieldMask{ + Paths: []string{"attributes.api_url_prefix"}, + }, + Item: &pb.AuthMethod{ + Attrs: &pb.AuthMethod_OidcAuthMethodsAttributes{ + OidcAuthMethodsAttributes: &pb.OidcAuthMethodAttributes{ + ApiUrlPrefix: wrapperspb.String("https://[2001:BEEF:0000:0000:0000:0000:0000:0001]:44344/path"), + }, + }, + }, + }, + res: &pbs.UpdateAuthMethodResponse{ + Item: &pb.AuthMethod{ + ScopeId: o.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "default"}, + Description: &wrapperspb.StringValue{Value: "default"}, + Type: oidc.Subtype.String(), + Attrs: func() *pb.AuthMethod_OidcAuthMethodsAttributes { + f := proto.Clone(defaultReadAttributes.OidcAuthMethodsAttributes).(*pb.OidcAuthMethodAttributes) + f.ApiUrlPrefix = wrapperspb.String("https://[2001:beef::1]:44344/path") + f.CallbackUrl = "https://[2001:beef::1]:44344/path/v1/auth-methods/oidc:authenticate:callback" + return &pb.AuthMethod_OidcAuthMethodsAttributes{OidcAuthMethodsAttributes: f} + }(), + Scope: defaultScopeInfo, + AuthorizedActions: oidcAuthorizedActions, + AuthorizedCollectionActions: authorizedCollectionActions, + }, + }, + }, { name: "Change Allowed Audiences", req: &pbs.UpdateAuthMethodRequest{ @@ -1127,9 +1193,7 @@ func TestUpdate_OIDC(t *testing.T) { if got.Item.GetOidcAuthMethodsAttributes().CallbackUrl != "" { exp := tc.res.Item.GetOidcAuthMethodsAttributes().GetCallbackUrl() gVal := got.Item.GetOidcAuthMethodsAttributes().GetCallbackUrl() - matches, err := regexp.MatchString(exp, gVal) - require.NoError(err) - assert.True(matches, "%q doesn't match %q", gVal, exp) + assert.Equal(exp, gVal, "%q doesn't match %q", exp, gVal) } assert.EqualValues(3, got.Item.Version) diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go index f47effbdb6..8841f6f566 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service.go @@ -32,6 +32,7 @@ import ( "github.com/hashicorp/boundary/internal/types/subtypes" pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/credentialstores" "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes" + "github.com/hashicorp/go-secure-stdlib/parseutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -834,6 +835,13 @@ func toStorageVaultStore(ctx context.Context, scopeId string, in *pb.CredentialS if attrs.GetWorkerFilter().GetValue() != "" { opts = append(opts, vault.WithWorkerFilter(attrs.GetWorkerFilter().GetValue())) } + if attrs.GetAddress().GetValue() != "" { + addr, err := parseutil.NormalizeAddr(attrs.GetAddress().GetValue()) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + attrs.Address = wrapperspb.String(addr) + } // TODO (ICU-1478 and ICU-1479): Update the vault's interface around ca cert to match oidc's, // accepting x509.Certificate instead of []byte diff --git a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go index 0c84b5b153..7ac17010a2 100644 --- a/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go +++ b/internal/daemon/controller/handlers/credentialstores/credentialstore_service_test.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "encoding/base64" "fmt" + "net/url" "slices" "strings" "testing" @@ -647,6 +648,64 @@ func TestCreateVault(t *testing.T) { }, }, }, + { + name: "Create a valid vault CredentialStore IPv6 Address", + req: &pbs.CreateCredentialStoreRequest{Item: &pb.CredentialStore{ + ScopeId: prj.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "name-ipv6"}, + Description: &wrapperspb.StringValue{Value: "desc-ipv6"}, + Type: vault.Subtype.String(), + Attrs: &pb.CredentialStore_VaultCredentialStoreAttributes{ + VaultCredentialStoreAttributes: &pb.VaultCredentialStoreAttributes{ + Address: func() *wrapperspb.StringValue { + u, err := url.Parse(v.Addr) + require.NoError(t, err) + require.NotNil(t, u) + require.NotEmpty(t, u.Port()) + require.NotEmpty(t, u.Scheme) + + return wrapperspb.String(fmt.Sprintf("%s://[0000:0000:0000:0000:0000:0000:0000:0001]:%s", u.Scheme, u.Port())) + }(), + Token: wrapperspb.String(newToken()), + CaCert: wrapperspb.String(string(v.CaCert)), + ClientCertificate: wrapperspb.String(string(v.ClientCert)), + ClientCertificateKey: wrapperspb.String(string(v.ClientKey)), + }, + }, + }}, + idPrefix: globals.VaultCredentialStorePrefix + "_", + res: &pbs.CreateCredentialStoreResponse{ + Uri: fmt.Sprintf("credential-stores/%s_", globals.VaultCredentialStorePrefix), + Item: &pb.CredentialStore{ + ScopeId: prj.GetPublicId(), + Name: &wrapperspb.StringValue{Value: "name-ipv6"}, + Description: &wrapperspb.StringValue{Value: "desc-ipv6"}, + Scope: &scopepb.ScopeInfo{Id: prj.GetPublicId(), Type: prj.GetType(), ParentScopeId: prj.GetParentId()}, + Version: 1, + Type: vault.Subtype.String(), + Attrs: &pb.CredentialStore_VaultCredentialStoreAttributes{ + VaultCredentialStoreAttributes: &pb.VaultCredentialStoreAttributes{ + CaCert: wrapperspb.String(string(v.CaCert)), + Address: func() *wrapperspb.StringValue { + u, err := url.Parse(v.Addr) + require.NoError(t, err) + require.NotNil(t, u) + require.NotEmpty(t, u.Port()) + require.NotEmpty(t, u.Scheme) + + return wrapperspb.String(fmt.Sprintf("%s://[::1]:%s", u.Scheme, u.Port())) + }(), + TokenHmac: "", + TokenStatus: "current", + ClientCertificate: wrapperspb.String(string(v.ClientCert)), + ClientCertificateKeyHmac: "", + }, + }, + AuthorizedActions: testAuthorizedActions, + AuthorizedCollectionActions: testAuthorizedVaultCollectionActions, + }, + }, + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/daemon/controller/handlers/hosts/host_service.go b/internal/daemon/controller/handlers/hosts/host_service.go index 9002c9929f..c856a2fc8d 100644 --- a/internal/daemon/controller/handlers/hosts/host_service.go +++ b/internal/daemon/controller/handlers/hosts/host_service.go @@ -6,7 +6,6 @@ package hosts import ( "context" "fmt" - "net" "strings" "github.com/hashicorp/boundary/globals" @@ -773,15 +772,13 @@ func validateCreateRequest(req *pbs.CreateHostRequest) error { len(attrs.GetAddress().GetValue()) > static.MaxHostAddressLength { badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) } else { - _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) - switch { - case err == nil: - badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." - case strings.Contains(err.Error(), globals.MissingPortErrStr): - // Bare hostname, which we want - default: + _, port, err := util.SplitHostPort(attrs.GetAddress().GetValue()) + if err != nil && !errors.Is(err, util.ErrMissingPort) { badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) } + if port != "" { + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + } } } case hostplugin.Subtype: @@ -810,15 +807,13 @@ func validateUpdateRequest(req *pbs.UpdateHostRequest) error { len(strings.TrimSpace(attrs.GetAddress().GetValue())) > static.MaxHostAddressLength { badFields[globals.AttributesAddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) } else { - _, _, err := net.SplitHostPort(attrs.GetAddress().GetValue()) - switch { - case err == nil: - badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." - case strings.Contains(err.Error(), globals.MissingPortErrStr): - // Bare hostname, which we want - default: + _, port, err := util.SplitHostPort(attrs.GetAddress().GetValue()) + if err != nil && !errors.Is(err, util.ErrMissingPort) { badFields[globals.AttributesAddressField] = fmt.Sprintf("Error parsing address: %v.", err) } + if port != "" { + badFields[globals.AttributesAddressField] = "Address for static hosts does not support a port." + } } } } diff --git a/internal/daemon/controller/handlers/hosts/host_service_test.go b/internal/daemon/controller/handlers/hosts/host_service_test.go index 6a39bf00f0..477f08fd63 100644 --- a/internal/daemon/controller/handlers/hosts/host_service_test.go +++ b/internal/daemon/controller/handlers/hosts/host_service_test.go @@ -1284,6 +1284,36 @@ func TestCreate(t *testing.T) { }, }, }, + { + name: "Create a valid Host with IPv6 address", + req: &pbs.CreateHostRequest{Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Name: &wrappers.StringValue{Value: "name-ipv6"}, + Description: &wrappers.StringValue{Value: "desc-ipv6"}, + Type: "static", + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("2001:BEEF:0000:0000:0000:0000:0000:0001"), + }, + }, + }}, + res: &pbs.CreateHostResponse{ + Uri: fmt.Sprintf("hosts/%s_", globals.StaticHostPrefix), + Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, + Name: &wrappers.StringValue{Value: "name-ipv6"}, + Description: &wrappers.StringValue{Value: "desc-ipv6"}, + Type: "static", + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("2001:beef::1"), + }, + }, + AuthorizedActions: testAuthorizedActions[static.Subtype], + }, + }, + }, { name: "no-attributes", req: &pbs.CreateHostRequest{Item: &pb.Host{ @@ -1549,6 +1579,41 @@ func TestUpdate_Static(t *testing.T) { }, }, }, + { + name: "Update address", + req: &pbs.UpdateHostRequest{ + Id: h.GetPublicId(), + UpdateMask: &field_mask.FieldMask{ + Paths: []string{globals.AttributesAddressField}, + }, + Item: &pb.Host{ + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("2001:BEEF:0000:0000:0000:0000:0000:0001"), + }, + }, + Type: "static", + }, + }, + res: &pbs.UpdateHostResponse{ + Item: &pb.Host{ + HostCatalogId: hc.GetPublicId(), + Id: h.GetPublicId(), + Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, + Name: &wrappers.StringValue{Value: "default"}, + Description: &wrappers.StringValue{Value: "default"}, + CreatedTime: h.GetCreateTime().GetTimestamp(), + Type: "static", + Attrs: &pb.Host_StaticHostAttributes{ + StaticHostAttributes: &pb.StaticHostAttributes{ + Address: wrapperspb.String("2001:beef::1"), + }, + }, + AuthorizedActions: testAuthorizedActions[static.Subtype], + HostSetIds: []string{s.GetPublicId()}, + }, + }, + }, { name: "Multiple Paths in single string", req: &pbs.UpdateHostRequest{ diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index e5ec643043..bc620f7ae2 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -910,9 +910,9 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession "No host was discovered after checking target address and host sources.") } - // Ensure we don't have a port from the address - _, err = util.ParseAddress(ctx, h) - if err != nil { + // Ensure we don't have a port from the address and that any ipv6 addresses + // are formatted properly + if h, err = util.ParseAddress(ctx, h); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error when parsing the chosen endpoint host address")) } @@ -1817,15 +1817,13 @@ func validateCreateRequest(req *pbs.CreateTargetRequest) error { } } if address := item.GetAddress(); address != nil { - if len(address.GetValue()) < static.MinHostAddressLength || - len(address.GetValue()) > static.MaxHostAddressLength { - badFields[globals.AddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) - } - _, _, err := net.SplitHostPort(address.GetValue()) + _, err := util.ParseAddress(context.Background(), address.GetValue()) switch { case err == nil: + case errors.Is(err, util.ErrInvalidAddressLength): + badFields[globals.AddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + case errors.Is(err, util.ErrInvalidAddressContainsPort): badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } @@ -1897,15 +1895,13 @@ func validateUpdateRequest(req *pbs.UpdateTargetRequest) error { } } if address := item.GetAddress(); address != nil { - if len(address.GetValue()) < static.MinHostAddressLength || - len(address.GetValue()) > static.MaxHostAddressLength { - badFields[globals.AddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) - } - _, _, err := net.SplitHostPort(address.GetValue()) + _, err := util.ParseAddress(context.Background(), address.GetValue()) switch { case err == nil: + case errors.Is(err, util.ErrInvalidAddressLength): + badFields[globals.AddressField] = fmt.Sprintf("Address length must be between %d and %d characters.", static.MinHostAddressLength, static.MaxHostAddressLength) + case errors.Is(err, util.ErrInvalidAddressContainsPort): badFields[globals.AddressField] = "Address does not support a port." - case strings.Contains(err.Error(), globals.MissingPortErrStr): default: badFields[globals.AddressField] = fmt.Sprintf("Error parsing address: %v.", err) } diff --git a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go index 7079a10ded..d6c581c257 100644 --- a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go @@ -9,7 +9,6 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" - "errors" "fmt" "path" "slices" @@ -34,6 +33,7 @@ import ( "github.com/hashicorp/boundary/internal/daemon/controller/handlers/credentials" "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" authpb "github.com/hashicorp/boundary/internal/gen/controller/auth" @@ -1595,6 +1595,82 @@ func TestCreate(t *testing.T) { res: nil, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid address ipv6 missing segment", + req: &pbs.CreateTargetRequest{Item: &pb.Target{ + ScopeId: proj.GetPublicId(), + Name: wrapperspb.String("name1"), + Description: wrapperspb.String("desc"), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{ + TcpTargetAttributes: &pb.TcpTargetAttributes{ + DefaultPort: wrapperspb.UInt32(2), + DefaultClientPort: wrapperspb.UInt32(3), + }, + }, + EgressWorkerFilter: wrapperspb.String(`type == "bar"`), + Address: wrapperspb.String("2001:BEEF:0:0:1:0:0001"), + }}, + res: nil, + errStr: "Error parsing address: host contains an invalid IPv6 literal.", + }, + { + name: "Invalid address ipv6 has brackets", + req: &pbs.CreateTargetRequest{Item: &pb.Target{ + ScopeId: proj.GetPublicId(), + Name: wrapperspb.String("name2"), + Description: wrapperspb.String("desc"), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{ + TcpTargetAttributes: &pb.TcpTargetAttributes{ + DefaultPort: wrapperspb.UInt32(2), + DefaultClientPort: wrapperspb.UInt32(3), + }, + }, + EgressWorkerFilter: wrapperspb.String(`type == "bar"`), + Address: wrapperspb.String("[2001:BEEF:0:0:0:1:0:0001]"), + }}, + res: nil, + errStr: "Error parsing address: address cannot be encapsulated by brackets", + }, + { + name: "Create a valid target with ipv6 address", + req: &pbs.CreateTargetRequest{Item: &pb.Target{ + ScopeId: proj.GetPublicId(), + Name: wrapperspb.String("valid ipv6"), + Description: wrapperspb.String("desc"), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{ + TcpTargetAttributes: &pb.TcpTargetAttributes{ + DefaultPort: wrapperspb.UInt32(2), + DefaultClientPort: wrapperspb.UInt32(3), + }, + }, + EgressWorkerFilter: wrapperspb.String(`type == "bar"`), + Address: wrapperspb.String("2001:BEEF:0:0:0:1:0:0001"), + }}, + res: &pbs.CreateTargetResponse{ + Uri: fmt.Sprintf("targets/%s_", globals.TcpTargetPrefix), + Item: &pb.Target{ + ScopeId: proj.GetPublicId(), + Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, + Name: wrapperspb.String("valid ipv6"), + Description: wrapperspb.String("desc"), + Type: tcp.Subtype.String(), + Attrs: &pb.Target_TcpTargetAttributes{ + TcpTargetAttributes: &pb.TcpTargetAttributes{ + DefaultPort: wrapperspb.UInt32(2), + DefaultClientPort: wrapperspb.UInt32(3), + }, + }, + SessionMaxSeconds: wrapperspb.UInt32(28800), + SessionConnectionLimit: wrapperspb.Int32(-1), + AuthorizedActions: testAuthorizedActions, + EgressWorkerFilter: wrapperspb.String(`type == "bar"`), + Address: wrapperspb.String("2001:beef::1:0:1"), + }, + }, + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -1618,7 +1694,7 @@ func TestCreate(t *testing.T) { assert.True(errors.Is(gErr, tc.err), "CreateTarget(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) } if tc.errStr != "" { - assert.ErrorContains(gErr, tc.errStr) + assert.ErrorContains(gErr, tc.errStr, "CreateTarget(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) } } else { assert.Nil(gErr, "Unexpected err: %v", gErr) @@ -2156,6 +2232,130 @@ func TestUpdate(t *testing.T) { res: nil, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + defer resetTarget() + assert, require := assert.New(t), require.New(t) + tc.req.Item.Version = tar.GetVersion() + + req := proto.Clone(toMerge).(*pbs.UpdateTargetRequest) + proto.Merge(req, tc.req) + + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + got, gErr := tested.UpdateTarget(ctx, req) + if tc.err != nil { + require.Error(gErr) + assert.True(errors.Is(gErr, tc.err), "UpdateTarget(%+v) got error %v, wanted %v", req, gErr, tc.err) + return + } + require.NoError(gErr) + + if got != nil { + assert.NotNilf(tc.res, "Expected UpdateHost response to be nil, but was %v", got) + gotUpdateTime := got.GetItem().GetUpdatedTime().AsTime() + // Verify it is a set updated after it was created + // TODO: This is currently failing. + assert.True(gotUpdateTime.After(hCreated), "Updated target should have been updated after it's creation. Was updated %v, which is after %v", gotUpdateTime, hCreated) + + // Clear all values which are hard to compare against. + got.Item.UpdatedTime, tc.res.Item.UpdatedTime = nil, nil + } + if tc.res != nil { + tc.res.Item.Version = tc.req.Item.Version + 1 + } + assert.Empty(cmp.Diff( + got, + tc.res, + protocmp.Transform(), + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), + ), "UpdateTarget(%q) got response %q, wanted %q", req, got, tc.res) + }) + } + // Reset worker filter funcs + targets.ValidateIngressWorkerFilterFn = validateIngressFn +} + +func TestUpdateAddress(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + + rw := db.New(conn) + iamRepo := iam.TestRepo(t, conn, wrapper) + iamRepoFn := func() (*iam.Repository, error) { + return iamRepo, nil + } + tokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(ctx, rw, rw, kms) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(ctx, rw, rw, kms) + } + + org, proj := iam.TestScopes(t, iamRepo) + at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) + r := iam.TestRole(t, conn, proj.GetPublicId()) + _ = iam.TestUserRole(t, conn, r.GetPublicId(), at.GetIamUserId()) + _ = iam.TestRoleGrant(t, conn, r.GetPublicId(), "ids=*;type=*;actions=*") + + repoFn := func(o ...target.Option) (*target.Repository, error) { + return target.NewRepository(ctx, rw, rw, kms) + } + repo, err := repoFn() + require.NoError(t, err, "Couldn't create new target repo.") + + ttar, err := target.New(ctx, tcp.Subtype, proj.GetPublicId(), + target.WithName("default"), + target.WithDescription("default"), + target.WithSessionMaxSeconds(1), + target.WithSessionConnectionLimit(1), + target.WithDefaultPort(2), + target.WithDefaultClientPort(3), + target.WithAddress("8.8.8.8"), + ) + require.NoError(t, err) + tar, err := repo.CreateTarget(context.Background(), ttar) + require.NoError(t, err) + + resetTarget := func() { + itar, err := repo.LookupTarget(context.Background(), tar.GetPublicId()) + require.NoError(t, err) + + tar, _, err = repo.UpdateTarget(context.Background(), tar, itar.GetVersion(), + []string{"Name", "Description", "SessionMaxSeconds", "SessionConnectionLimit", "DefaultPort", "DefaultClientPort"}) + require.NoError(t, err, "Failed to reset target.") + } + + hCreated := tar.GetCreateTime().GetTimestamp().AsTime() + toMerge := &pbs.UpdateTargetRequest{ + Id: tar.GetPublicId(), + } + + tested, err := testService(t, context.Background(), conn, kms, wrapper) + require.NoError(t, err, "Failed to create a new host set service.") + + // Ensure we are using the OSS worker filter functions. This prevents us + // from running tests in parallel. + server.TestUseCommunityFilterWorkersFn(t) + validateIngressFn := targets.ValidateIngressWorkerFilterFn + targets.ValidateIngressWorkerFilterFn = targets.IngressWorkerFilterUnsupported + + cases := []struct { + name string + req *pbs.UpdateTargetRequest + res *pbs.UpdateTargetResponse + err string + }{ { name: "Invalid address length", req: &pbs.UpdateTargetRequest{ @@ -2167,7 +2367,7 @@ func TestUpdate(t *testing.T) { }, }, res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: "Address length must be between 3 and 255 characters", }, { name: "Invalid address w/ port", @@ -2180,7 +2380,7 @@ func TestUpdate(t *testing.T) { }, }, res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: "Address does not support a port", }, { name: "Invalid address not parsable", @@ -2193,7 +2393,63 @@ func TestUpdate(t *testing.T) { }, }, res: nil, - err: handlers.ApiErrorWithCode(codes.InvalidArgument), + err: "Error parsing address: failed to parse address.", + }, + { + name: "Update address valid ipv6", + req: &pbs.UpdateTargetRequest{ + UpdateMask: &field_mask.FieldMask{ + Paths: []string{"address"}, + }, + Item: &pb.Target{ + Address: wrapperspb.String("2001:BEEF:0:0:0:1:0:0001"), + }, + }, + res: &pbs.UpdateTargetResponse{ + Item: &pb.Target{ + Id: tar.GetPublicId(), + ScopeId: tar.GetProjectId(), + Scope: &scopes.ScopeInfo{Id: proj.GetPublicId(), Type: scope.Project.String(), ParentScopeId: org.GetPublicId()}, + Name: wrapperspb.String("default"), + Description: wrapperspb.String("default"), + CreatedTime: tar.GetCreateTime().GetTimestamp(), + Attrs: &pb.Target_TcpTargetAttributes{ + TcpTargetAttributes: &pb.TcpTargetAttributes{ + DefaultPort: wrapperspb.UInt32(2), + DefaultClientPort: wrapperspb.UInt32(3), + }, + }, + Type: tcp.Subtype.String(), + SessionMaxSeconds: wrapperspb.UInt32(tar.GetSessionMaxSeconds()), + SessionConnectionLimit: wrapperspb.Int32(tar.GetSessionConnectionLimit()), + AuthorizedActions: testAuthorizedActions, + Address: wrapperspb.String("2001:beef::1:0:1"), + }, + }, + }, + { + name: "Update address invalid ipv6 with brackets", + req: &pbs.UpdateTargetRequest{ + UpdateMask: &field_mask.FieldMask{ + Paths: []string{"address"}, + }, + Item: &pb.Target{ + Address: wrapperspb.String("[2001:BEEF:0:0:0:1:0:0001]"), + }, + }, + err: "Error parsing address: address cannot be encapsulated by brackets.", + }, + { + name: "Update address invalid ipv6 missing segment", + req: &pbs.UpdateTargetRequest{ + UpdateMask: &field_mask.FieldMask{ + Paths: []string{"address"}, + }, + Item: &pb.Target{ + Address: wrapperspb.String("2001:BEEF:0:0:1:0:0001"), + }, + }, + err: "Error parsing address: host contains an invalid IPv6 literal.", }, } for _, tc := range cases { @@ -2213,11 +2469,12 @@ func TestUpdate(t *testing.T) { requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) got, gErr := tested.UpdateTarget(ctx, req) - if tc.err != nil { + if tc.err != "" { require.Error(gErr) - assert.True(errors.Is(gErr, tc.err), "UpdateTarget(%+v) got error %v, wanted %v", req, gErr, tc.err) + assert.ErrorContainsf(gErr, tc.err, "UpdateTarget(%+v) got error %v, wanted %v", req, gErr, tc.err) return } + require.NoError(gErr) if got != nil { diff --git a/internal/daemon/worker/controller_connection.go b/internal/daemon/worker/controller_connection.go index d571759508..bb4b8e74d4 100644 --- a/internal/daemon/worker/controller_connection.go +++ b/internal/daemon/worker/controller_connection.go @@ -52,7 +52,7 @@ func (w *Worker) StartControllerConnections() error { initialAddrs = append(initialAddrs, addr) default: host, port, err := util.SplitHostPort(addr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return fmt.Errorf("error parsing upstream address: %w", err) } if port == "" { diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index a7b2b12463..bce8628c65 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -294,7 +294,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa ConnectionId: acResp.GetConnectionId(), ClientTcpAddress: clientAddr.IP.String(), ClientTcpPort: uint32(clientAddr.Port), - EndpointTcpAddress: endpointAddr.Ip(), + EndpointTcpAddress: endpointAddr.Ip(), // endpointAddr.ip is assigned via net.IP and therefore should already be formatted correctly EndpointTcpPort: endpointAddr.Port(), Type: endpointUrl.Scheme, UserClientIp: userClientIp, diff --git a/internal/host/plugin/job_set_sync_test.go b/internal/host/plugin/job_set_sync_test.go index 870cd3790a..61fee520f4 100644 --- a/internal/host/plugin/job_set_sync_test.go +++ b/internal/host/plugin/job_set_sync_test.go @@ -201,7 +201,7 @@ func TestSetSyncJob_Run(t *testing.T) { Hosts: []*plgpb.ListHostsResponseHost{ { ExternalId: "first", - IpAddresses: []string{fmt.Sprintf("10.0.0.%d", *counter), testGetIpv6Address(t)}, + IpAddresses: []string{fmt.Sprintf("10.0.0.%d", *counter), testGetIpv6Address(t), "2001:BEEF:0000:0000:0000:0000:0000:0001"}, DnsNames: []string{"foo.com"}, SetIds: setIds, }, @@ -228,13 +228,14 @@ func TestSetSyncJob_Run(t *testing.T) { assert.Len(hosts, 1) for _, host := range hosts { assert.Equal(uint32(1), host.Version) - require.Len(host.IpAddresses, 2) + require.Len(host.IpAddresses, 3) ipv4 := net.ParseIP(host.IpAddresses[0]) require.NotNil(ipv4) require.NotNil(ipv4.To4()) ipv6 := net.ParseIP(host.IpAddresses[1]) require.NotNil(ipv6) require.NotNil(ipv6.To16()) + require.Contains(host.IpAddresses, "2001:beef::1") } require.NoError(rw.LookupByPublicId(ctx, hsa)) diff --git a/internal/host/static/repository_host_test.go b/internal/host/static/repository_host_test.go index cbec7485d1..ad455318e1 100644 --- a/internal/host/static/repository_host_test.go +++ b/internal/host/static/repository_host_test.go @@ -145,7 +145,7 @@ func TestRepository_CreateHost(t *testing.T) { want: &Host{ Host: &store.Host{ CatalogId: catalog.PublicId, - Address: "2001:4860:4860:0:0:0:0:8888", + Address: "2001:4860:4860::8888", }, }, }, @@ -159,36 +159,6 @@ func TestRepository_CreateHost(t *testing.T) { }, wantIsErr: errors.InvalidAddress, }, - { - name: "valid-abbreviated-[ipv6]-address", - in: &Host{ - Host: &store.Host{ - CatalogId: catalog.PublicId, - Address: "[2001:4860:4860::8888]", - }, - }, - want: &Host{ - Host: &store.Host{ - CatalogId: catalog.PublicId, - Address: "[2001:4860:4860::8888]", - }, - }, - }, - { - name: "valid-[ipv6]-address", - in: &Host{ - Host: &store.Host{ - CatalogId: catalog.PublicId, - Address: "[2001:4860:4860:0:0:0:0:8888]", - }, - }, - want: &Host{ - Host: &store.Host{ - CatalogId: catalog.PublicId, - Address: "[2001:4860:4860:0:0:0:0:8888]", - }, - }, - }, { name: "valid-with-name", in: &Host{ @@ -741,39 +711,7 @@ func TestRepository_UpdateHost(t *testing.T) { masks: []string{"Address"}, want: &Host{ Host: &store.Host{ - Address: "2001:4860:4860:0:0:0:0:8888", - }, - }, - wantCount: 1, - }, - { - name: "change-abbreviated-[ipv6]-address", - orig: &Host{ - Host: &store.Host{ - Address: "127.0.0.1", - }, - }, - chgFn: changeAddress("[2001:4860:4860::8888]"), - masks: []string{"Address"}, - want: &Host{ - Host: &store.Host{ - Address: "[2001:4860:4860::8888]", - }, - }, - wantCount: 1, - }, - { - name: "change-[ipv6]-address", - orig: &Host{ - Host: &store.Host{ - Address: "127.0.0.1", - }, - }, - chgFn: changeAddress("[2001:4860:4860:0:0:0:0:8888]"), - masks: []string{"Address"}, - want: &Host{ - Host: &store.Host{ - Address: "[2001:4860:4860:0:0:0:0:8888]", + Address: "2001:4860:4860::8888", }, }, wantCount: 1, diff --git a/internal/ratelimit/handler_test.go b/internal/ratelimit/handler_test.go index 952283c470..398e1ab38a 100644 --- a/internal/ratelimit/handler_test.go +++ b/internal/ratelimit/handler_test.go @@ -67,7 +67,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -109,7 +109,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -151,7 +151,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -193,7 +193,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -235,7 +235,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -277,7 +277,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -319,7 +319,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusOK, http.Header{ @@ -369,7 +369,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusTooManyRequests, http.Header{ @@ -436,7 +436,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusServiceUnavailable, http.Header{ @@ -477,7 +477,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusInternalServerError, http.Header{}, @@ -516,7 +516,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusNotFound, http.Header{}, @@ -555,7 +555,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusBadRequest, http.Header{}, @@ -594,7 +594,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusMethodNotAllowed, http.Header{}, @@ -633,7 +633,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusMethodNotAllowed, http.Header{}, @@ -672,7 +672,7 @@ func TestHandler(t *testing.T) { require.NoError(t, err) return r }, - "[::1]", + "::1", "authtoken", http.StatusMethodNotAllowed, http.Header{}, @@ -756,7 +756,7 @@ func TestHandlerErrors(t *testing.T) { ctx, err = event.NewRequestInfoContext(ctx, &event.RequestInfo{ Id: id, EventId: common.GeneratedTraceId(ctx), - ClientIp: "[::1]", + ClientIp: "::1", }) require.NoError(t, err) return ctx diff --git a/internal/session/session.go b/internal/session/session.go index 5f8b7b9994..152d71b13d 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -473,7 +473,7 @@ func newCert(ctx context.Context, jobId string, addresses []string, exp time.Tim for _, addr := range addresses { // First ensure we aren't looking at ports, regardless of IP or not host, _, err := util.SplitHostPort(addr) - if err != nil { + if err != nil && !errors.Is(err, util.ErrMissingPort) { return nil, nil, errors.Wrap(ctx, err, op) } // Now figure out if it's an IP address or not. If ParseIP likes it, add diff --git a/internal/session/session_connect_with_test.go b/internal/session/session_connect_with_test.go index f156905c8e..a9686b2970 100644 --- a/internal/session/session_connect_with_test.go +++ b/internal/session/session_connect_with_test.go @@ -43,11 +43,11 @@ func TestConnectWith_validate(t *testing.T) { name: "valid-ipv6", fields: fields{ SessionId: id, - ClientTcpAddress: "[::1]", + ClientTcpAddress: "::1", ClientTcpPort: 22, - EndpointTcpAddress: "[::1]", + EndpointTcpAddress: "::1", EndpointTcpPort: 2222, - UserClientIp: "[::2]", + UserClientIp: "::2", }, }, { diff --git a/internal/target/tcp/repository_tcp_target_test.go b/internal/target/tcp/repository_tcp_target_test.go index 807578e318..9955c26605 100644 --- a/internal/target/tcp/repository_tcp_target_test.go +++ b/internal/target/tcp/repository_tcp_target_test.go @@ -127,13 +127,13 @@ func TestRepository_CreateTarget(t *testing.T) { target.WithName("with-abbreviated-ipv6-address"), target.WithDescription("with-abbreviated-ipv6-address"), target.WithDefaultPort(80), - target.WithAddress("2001:4860:4860::8888")) + target.WithAddress("2001:BEEF:4860::8888")) require.NoError(t, err) return target }(), }, wantErr: false, - wantAddress: "2001:4860:4860::8888", + wantAddress: "2001:beef:4860::8888", }, { name: "with-ipv6-address", @@ -143,13 +143,13 @@ func TestRepository_CreateTarget(t *testing.T) { target.WithName("with-ipv6-address"), target.WithDescription("with-ipv6-address"), target.WithDefaultPort(80), - target.WithAddress("2001:4860:4860:0:0:0:0:8888")) + target.WithAddress("2001:BEEF:4860:0:0:0:0:8888")) require.NoError(t, err) return target }(), }, wantErr: false, - wantAddress: "2001:4860:4860::8888", + wantAddress: "2001:beef:4860::8888", }, { name: "with-abbreviated-[ipv6]-address", @@ -164,8 +164,8 @@ func TestRepository_CreateTarget(t *testing.T) { return target }(), }, - wantErr: false, - wantAddress: "2001:4860:4860::8888", + wantErr: true, + wantIsError: errors.InvalidAddress, }, { name: "with-invalid-abbreviated-[ipv6]-address-with-port", @@ -196,8 +196,8 @@ func TestRepository_CreateTarget(t *testing.T) { return target }(), }, - wantErr: false, - wantAddress: "2001:4860:4860:0:0:0:0:8888", + wantErr: true, + wantIsError: errors.InvalidAddress, }, { name: "with-invalid-[ipv6]-address-with-port", @@ -521,13 +521,13 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { name: "valid-abbreviated-ipv6-address" + id, fieldMaskPaths: []string{"Name", "Address"}, ProjectId: proj.PublicId, - address: "2001:4860:4860::8888", + address: "2001:BEEF:4860::8888", }, newProjectId: proj.PublicId, wantErr: false, wantRowsUpdate: 1, wantHostSources: false, - wantAddress: "2001:4860:4860::8888", + wantAddress: "2001:beef:4860::8888", }, { name: "valid-ipv6-address", @@ -535,13 +535,13 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { name: "valid-ipv6-address" + id, fieldMaskPaths: []string{"Name", "Address"}, ProjectId: proj.PublicId, - address: "2001:4860:4860:0:0:0:0:8888", + address: "2001:BEEF:4860:0:0:0:0:8888", }, newProjectId: proj.PublicId, wantErr: false, wantRowsUpdate: 1, wantHostSources: false, - wantAddress: "2001:4860:4860::8888", + wantAddress: "2001:beef:4860::8888", }, { name: "valid-abbreviated-[ipv6]-address", @@ -551,11 +551,10 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { ProjectId: proj.PublicId, address: "[2001:4860:4860::8888]", }, - newProjectId: proj.PublicId, - wantErr: false, - wantRowsUpdate: 1, - wantHostSources: false, - wantAddress: "2001:4860:4860::8888", + newProjectId: proj.PublicId, + wantErr: true, + wantIsError: errors.InvalidAddress, + wantErrMsg: "invalid address", }, { name: "invalid-abbreviated-[ipv6]-address-with-port", @@ -578,11 +577,10 @@ func TestRepository_UpdateTcpTarget(t *testing.T) { ProjectId: proj.PublicId, address: "[2001:4860:4860:0:0:0:0:8888]", }, - newProjectId: proj.PublicId, - wantErr: false, - wantRowsUpdate: 1, - wantHostSources: false, - wantAddress: "2001:4860:4860:0:0:0:0:8888", + newProjectId: proj.PublicId, + wantErr: true, + wantIsError: errors.InvalidAddress, + wantErrMsg: "invalid address", }, { name: "invalid-[ipv6]-address-with-port", diff --git a/internal/tests/api/targets/target_test.go b/internal/tests/api/targets/target_test.go index cf9146b592..333381a92f 100644 --- a/internal/tests/api/targets/target_test.go +++ b/internal/tests/api/targets/target_test.go @@ -380,7 +380,7 @@ func TestTarget_AddressMutualExclusiveRelationship(t *testing.T) { // Create target with a network address association targetResp, err := tClient.Create(tc.Context(), "tcp", proj.GetPublicId(), - targets.WithName("test-address"), targets.WithAddress("[::1]"), targets.WithTcpTargetDefaultPort(22)) + targets.WithName("test-address"), targets.WithAddress("::1"), targets.WithTcpTargetDefaultPort(22)) require.NoError(t, err) require.NotNil(t, targetResp) require.Equal(t, "::1", targetResp.GetItem().Address) @@ -392,7 +392,7 @@ func TestTarget_AddressMutualExclusiveRelationship(t *testing.T) { hs, err := hostsets.NewClient(client).Create(tc.Context(), hc.Item.Id) require.NoError(t, err) require.NotNil(t, hs) - h, err := hosts.NewClient(client).Create(tc.Context(), hc.Item.Id, hosts.WithStaticHostAddress("[::1]")) + h, err := hosts.NewClient(client).Create(tc.Context(), hc.Item.Id, hosts.WithStaticHostAddress("::1")) require.NoError(t, err) require.NotNil(t, h) hUpdate, err := hostsets.NewClient(client).AddHosts(tc.Context(), hs.Item.Id, hs.Item.Version, []string{h.GetItem().Id}) @@ -438,7 +438,7 @@ func TestTarget_HostSourceMutualExclusiveRelationship(t *testing.T) { hs, err := hostsets.NewClient(client).Create(tc.Context(), hc.Item.Id) require.NoError(t, err) require.NotNil(t, hs) - h, err := hosts.NewClient(client).Create(tc.Context(), hc.Item.Id, hosts.WithStaticHostAddress("[::1]")) + h, err := hosts.NewClient(client).Create(tc.Context(), hc.Item.Id, hosts.WithStaticHostAddress("::1")) require.NoError(t, err) require.NotNil(t, h) hUpdate, err := hostsets.NewClient(client).AddHosts(tc.Context(), hs.Item.Id, hs.Item.Version, []string{h.GetItem().Id}) @@ -461,7 +461,7 @@ func TestTarget_HostSourceMutualExclusiveRelationship(t *testing.T) { require.Empty(t, updateResp.GetItem().Address) require.Equal(t, []string{hs.Item.Id}, updateResp.GetItem().HostSourceIds) version = updateResp.GetItem().Version - updateResp, err = tClient.Update(tc.Context(), targetId, version, targets.WithAddress("[::1]")) + updateResp, err = tClient.Update(tc.Context(), targetId, version, targets.WithAddress("::1")) require.Error(t, err) require.Nil(t, updateResp) apiErr := api.AsServerError(err) @@ -474,7 +474,7 @@ func TestTarget_HostSourceMutualExclusiveRelationship(t *testing.T) { require.NotNil(t, updateResp) require.Empty(t, updateResp.GetItem().HostSourceIds) version = updateResp.GetItem().Version - updateResp, err = tClient.Update(tc.Context(), targetId, version, targets.WithAddress("[::1]")) + updateResp, err = tClient.Update(tc.Context(), targetId, version, targets.WithAddress("::1")) require.NoError(t, err) require.NotNil(t, updateResp) require.Equal(t, "::1", updateResp.GetItem().Address) @@ -502,13 +502,13 @@ func TestCreateTarget_DirectlyAttachedAddress(t *testing.T) { }, { name: "target-ipv6-address", - address: "[2001:4860:4860:0:0:0:0:8888]", - expectedAddress: "2001:4860:4860:0:0:0:0:8888", + address: "2001:BEEF:4860:0:0:0:0:8888", + expectedAddress: "2001:beef:4860::8888", }, { name: "target-abbreviated-ipv6-address", - address: "[2001:4860:4860::8888]", - expectedAddress: "2001:4860:4860::8888", + address: "2001:BEEF:4860::8888", + expectedAddress: "2001:beef:4860::8888", }, { name: "target-dns-address", diff --git a/internal/tests/cluster/parallel/unix_listener_test.go b/internal/tests/cluster/parallel/unix_listener_test.go index d125b22f0c..a5d8f22f2f 100644 --- a/internal/tests/cluster/parallel/unix_listener_test.go +++ b/internal/tests/cluster/parallel/unix_listener_test.go @@ -95,6 +95,22 @@ func TestUnixListener(t *testing.T) { helper.ExpectWorkers(t, c1) require.NoError(c1.Controller().Shutdown()) + + conf, err = config.DevController() + require.NoError(err) + + for _, l := range conf.Listeners { + switch l.Purpose[0] { + case "api": + l.Address = path.Join(tempDir, "api") + l.Type = "unix" + + case "cluster": + l.Address = path.Join(tempDir, "cluster") + l.Type = "unix" + } + } + c2 := controller.NewTestController(t, &controller.TestControllerOpts{ Config: conf, Logger: logger.Named("c2"), diff --git a/internal/util/net.go b/internal/util/net.go index 874929906a..f0b6ad97ca 100644 --- a/internal/util/net.go +++ b/internal/util/net.go @@ -10,16 +10,51 @@ import ( "regexp" "strings" - "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/go-secure-stdlib/parseutil" ) const ( - // MinAddressLength + // MinAddressLength is the minimum length for an address. MinAddressLength = 3 - // MaxAddressLength + // MaxAddressLength is the maximum length for an address. MaxAddressLength = 255 ) +var ( + // ErrMissingPort is returned from SplitHostPort when the underlying + // net.SplitHostPort call detects the input did not contain a port. This is + // the case for an input like "127.0.0.1" (but not "127.0.0.1:"). + ErrMissingPort = errors.New("missing port in address") + // ErrTooManyColons is returned from SplitHostPort when the underlying + // net.SplitHostPort call detects the input has more colons than it is + // expected to have. This is the case for an input like + // "127.0.0.1:1010:1010". + ErrTooManyColons = errors.New("too many colons in address") + // ErrMissingRBracket is returned from SplitHostPort when the underlying + // net.SplitHostPort call detects an input that starts with '[' but has no + // corresponding ']' closing bracket. This is the case for an input like + // "[::1:9090". + ErrMissingRBracket = errors.New("missing ']' in address") + // ErrUnexpectedLBracket is returned from SplitHostPort when the underlying + // net.SplitHostPort call detects an input that has an unexpected '[' + // character where it is not supposed to be. This is the case for an input + // like "127.0.[0.1:9090" or "[[127.0.0.1]:9090" (but not + // "[127.0.0.1]:9090"). + ErrUnexpectedLBracket = errors.New("unexpected '[' in address") + // ErrUnexpectedRBracket is returned from SplitHostPort when the underlying + // net.SplitHostPort call detects an input that has an unexpected ']' + // character where it is not supposed to be. This is the case for an input + // like "127.0.]0.1:9090" or "127.0.0.1]:9090" (but not "[127.0.0.1]:9090"). + ErrUnexpectedRBracket = errors.New("unexpected ']' in address") + + // ErrInvalidAddressLength is returned when an address input is not within + // defined lengths (see MinAddressLength and MaxAddressLength). + ErrInvalidAddressLength = errors.New("invalid address length") + // ErrInvalidAddressContainsPort is returned when an address input contains + // a port. + ErrInvalidAddressContainsPort = errors.New("address contains a port") +) + // This regular expression is used to find all instances of square brackets within a string. // This regular expression is used to remove the square brackets from an IPv6 address. var squareBrackets = regexp.MustCompile("\\[|\\]") @@ -31,18 +66,69 @@ func JoinHostPort(host, port string) string { return net.JoinHostPort(host, port) } -// SplitHostPort splits a network address of the form "host:port", "host%zone:port", "[host]:port" or "[host%zone]:port" into host or host%zone and port. +// SplitHostPort splits a network address of the form "host:port", +// "host%zone:port", "[host]:port" or "[host%zone]:port" into separate "host" or +// "host%zone" and "port". It differs from its standard library counterpart in +// the following ways: +// - If the input is an IP address (with no port), this function will return +// that IP as the `host`, empty `port`, and ErrMissingPort. +// - If the input is just a host (with no port), this function will return +// that host as the `host`, empty `port`, and ErrMissingPort. // -// A literal IPv6 address in hostport must be enclosed in square brackets, as in "[::1]:80", "[::1%lo0]:80". +// These changes enable inputs like "ip_address" or "host" and allows callers to +// detect whether any given `hostport` contains a port or is just a host/IP. func SplitHostPort(hostport string) (host string, port string, err error) { + // In case `hostport` is just an ip, we can grab that early. + if ip := net.ParseIP(hostport); ip != nil { + // If ParseIP successfully parsed it, it means `hostport` does not have + // a port (or is a malformed IPv6 address like "::1:1234"). + host = ip.String() + err = ErrMissingPort + return + } + + // At this time, we don't necessarily know that `hostport` is a string + // composed of a host and a port, however net.SplitHostPort will error if + // that is not the case. host, port, err = net.SplitHostPort(hostport) - // use the hostport value as a backup when we have a missing port error - if err != nil && strings.Contains(err.Error(), globals.MissingPortErrStr) { - // incase the hostport value is an ipv6, we must remove the enclosed square - // brackets to retain the same behavior as the net.SplitHostPort() method - host = squareBrackets.ReplaceAllString(hostport, "") - err = nil + if err != nil { + addrErr := new(net.AddrError) + isAddrErr := errors.As(err, &addrErr) + if !isAddrErr { + return + } + + // Since net.SplitHostPort does not type the error reason, we'll handle + // that here to simplify logic in callers of this function. Note that + // while this list covers every error state in net.SplitHostPort up to + // Go 1.24.1, error reasons might expand over time. + // See: https://cs.opensource.google/go/go/+/refs/tags/go1.24.1:src/net/ipsock.go;l=165-218 + const ( + stdlibErrReasonMissingPort = "missing port in address" + stdlibErrReasonTooManyColons = "too many colons in address" + stdlibErrReasonMissingRBracket = "missing ']' in address" + stdlibErrReasonUnexpectedLBracket = "unexpected '[' in address" + stdlibErrReasonUnexpectedRBracket = "unexpected ']' in address" + ) + switch { + case strings.Contains(addrErr.Err, stdlibErrReasonMissingPort): + // In case the `hostport` value is an IPv6 address, we must remove + // the brackets (if they exist) to retain the same behavior as + // net.SplitHostPort. This case wouldn't be caught by net.ParseIP + // because "[ipv6_address]" is not a valid input to that function. + host = squareBrackets.ReplaceAllString(hostport, "") + err = ErrMissingPort + case strings.Contains(addrErr.Err, stdlibErrReasonTooManyColons): + err = ErrTooManyColons + case strings.Contains(addrErr.Err, stdlibErrReasonMissingRBracket): + err = ErrMissingRBracket + case strings.Contains(addrErr.Err, stdlibErrReasonUnexpectedLBracket): + err = ErrUnexpectedLBracket + case strings.Contains(addrErr.Err, stdlibErrReasonUnexpectedRBracket): + err = ErrUnexpectedRBracket + } } + return } @@ -55,18 +141,11 @@ func ParseAddress(ctx context.Context, address string) (string, error) { const op = "util.ParseAddress" address = strings.TrimSpace(address) if len(address) < MinAddressLength || len(address) > MaxAddressLength { - return "", errors.New("invalid address length") - } - host, port, err := SplitHostPort(address) - if err != nil { - ip := net.ParseIP(address) - if ip.To4() == nil && ip.To16() == nil { - return "", err - } - host = ip.String() + return "", ErrInvalidAddressLength } + _, port, _ := SplitHostPort(address) if port != "" { - return "", errors.New("address contains a port") + return "", ErrInvalidAddressContainsPort } - return host, nil + return parseutil.NormalizeAddr(address) } diff --git a/internal/util/net_test.go b/internal/util/net_test.go index 1767a9bfe9..ba755a252a 100644 --- a/internal/util/net_test.go +++ b/internal/util/net_test.go @@ -195,11 +195,11 @@ func Test_SplitHostPort(t *testing.T) { }) tests := []struct { - name string - hostport string - expectedHost string - expectedPort string - expectedErrMsg string + name string + hostport string + expectedHost string + expectedPort string + expectedErr error }{ { name: "local-ipv4", @@ -214,9 +214,10 @@ func Test_SplitHostPort(t *testing.T) { expectedPort: "80", }, { - name: "ipv4-ignore-missing-port", + name: "ipv4-missing-port", hostport: "8.8.8.8", expectedHost: "8.8.8.8", + expectedErr: ErrMissingPort, }, { name: "ipv4-empty-port", @@ -224,20 +225,22 @@ func Test_SplitHostPort(t *testing.T) { expectedHost: "8.8.8.8", }, { - name: "ipv4-square-bracket", + name: "ipv4-square-brackets", hostport: "[8.8.8.8]:80", expectedHost: "8.8.8.8", expectedPort: "80", }, { - name: "ipv6-missing-square-brackets", - hostport: "::1:80", - expectedErrMsg: "address ::1:80: too many colons in address", + name: "ipv6-square-brackets", + hostport: "::1:80", + expectedHost: "::1:80", + expectedErr: ErrMissingPort, }, { - name: "ipv6-ignore-missing-port", + name: "ipv6-missing-port", hostport: "[::1]", expectedHost: "::1", + expectedErr: ErrMissingPort, }, { name: "ipv6-empty-port", @@ -266,16 +269,14 @@ func Test_SplitHostPort(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - require, assert := require.New(t), assert.New(t) actualHost, actualPort, err := SplitHostPort(tt.hostport) - if tt.expectedErrMsg != "" { - require.Error(err) - assert.ErrorContains(err, tt.expectedErrMsg) - return + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + } else { + require.NoError(t, err) } - require.NoError(err) - assert.Equal(tt.expectedHost, actualHost) - assert.Equal(tt.expectedPort, actualPort) + require.Equal(t, tt.expectedHost, actualHost) + require.Equal(t, tt.expectedPort, actualPort) }) } } @@ -333,14 +334,14 @@ func Test_ParseAddress(t *testing.T) { expectedAddress: "2001:4860:4860::8888", }, { - name: "valid-[ipv6]", - address: "[2001:4860:4860:0:0:0:0:8888]", - expectedAddress: "2001:4860:4860:0:0:0:0:8888", + name: "valid-[ipv6]", + address: "[2001:4860:4860:0:0:0:0:8888]", + expectedErrMsg: "address cannot be encapsulated by brackets", }, { - name: "valid-[ipv6]:", - address: "[2001:4860:4860:0:0:0:0:8888]:", - expectedAddress: "2001:4860:4860:0:0:0:0:8888", + name: "valid-[ipv6]:", + address: "[2001:4860:4860:0:0:0:0:8888]:", + expectedErrMsg: "url has malformed host: missing port value after colon", }, { name: "invalid-ipv6-with-port", @@ -353,14 +354,14 @@ func Test_ParseAddress(t *testing.T) { expectedAddress: "2001:4860:4860::8888", }, { - name: "valid-abbreviated-[ipv6]", - address: "[2001:4860:4860::8888]", - expectedAddress: "2001:4860:4860::8888", + name: "valid-abbreviated-[ipv6]", + address: "[2001:4860:4860::8888]", + expectedErrMsg: "address cannot be encapsulated by brackets", }, { - name: "valid-abbreviated-[ipv6]:", - address: "[2001:4860:4860::8888]:", - expectedAddress: "2001:4860:4860::8888", + name: "valid-abbreviated-[ipv6]:", + address: "[2001:4860:4860::8888]:", + expectedErrMsg: "url has malformed host: missing port value after colon", }, { name: "invalid-abbreviated-[ipv6]-with-port",