diff --git a/internal/cmd/commands/connect/client_tls_config.go b/internal/cmd/commands/connect/client_tls_config.go new file mode 100644 index 0000000000..213b6f4284 --- /dev/null +++ b/internal/cmd/commands/connect/client_tls_config.go @@ -0,0 +1,69 @@ +package connect + +import ( + "crypto/ed25519" + "crypto/tls" + "crypto/x509" + "fmt" + + targetspb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/targets" +) + +// ClientTlsConfig creates a TLS configuration from the session authorization +// data and host +func ClientTlsConfig(sessionAuthzData *targetspb.SessionAuthorizationData, host string) (*tls.Config, error) { + const op = "connect.ClientTlsConfig" + if sessionAuthzData == nil { + return nil, fmt.Errorf("%s: nil session authorization data", op) + } + parsedCert, err := x509.ParseCertificate(sessionAuthzData.Certificate) + if err != nil { + return nil, fmt.Errorf("unable to decode mTLS certificate: %w", err) + } + + certPool := x509.NewCertPool() + certPool.AddCert(parsedCert) + + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{sessionAuthzData.Certificate}, + PrivateKey: ed25519.PrivateKey(sessionAuthzData.PrivateKey), + Leaf: parsedCert, + }, + }, + ServerName: host, + MinVersion: tls.VersionTLS13, + NextProtos: []string{"http/1.1", sessionAuthzData.SessionId}, + + // This is set this way so we can make use of VerifyConnection, which we + // set on this TLS config below. We are not skipping verification! + InsecureSkipVerify: true, + } + if host == "" { + tlsConf.ServerName = parsedCert.DNSNames[0] + } + + // We disable normal DNS SAN behavior as we don't rely on DNS or IP + // addresses for security and want to avoid issues with including localhost + // etc. + verifyOpts := x509.VerifyOptions{ + DNSName: sessionAuthzData.SessionId, + Roots: certPool, + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + } + tlsConf.VerifyConnection = func(cs tls.ConnectionState) error { + // Go will not run this without at least one peer certificate, but + // doesn't hurt to check + if len(cs.PeerCertificates) == 0 { + return fmt.Errorf("%s: no peer certificates provided", op) + } + _, err := cs.PeerCertificates[0].Verify(verifyOpts) + return err + } + + return tlsConf, nil +} diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index 5def593f07..4c66174995 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -2,9 +2,7 @@ package connect import ( "context" - "crypto/ed25519" "crypto/tls" - "crypto/x509" "encoding/json" "errors" "fmt" @@ -437,13 +435,12 @@ func (c *Command) Run(args []string) (retCode int) { } } - parsedCert, err := x509.ParseCertificate(c.sessionAuthzData.Certificate) + tlsConf, err := ClientTlsConfig(c.sessionAuthzData, workerHost) if err != nil { - c.PrintCliError(fmt.Errorf("Unable to decode mTLS certificate: %w", err)) - return base.CommandUserError + c.PrintCliError(fmt.Errorf("Error creating TLS configuration: %w", err)) + return base.CommandCliError } - - c.expiration = parsedCert.NotAfter + c.expiration = tlsConf.Certificates[0].Leaf.NotAfter // We don't _rely_ on client-side timeout verification but this prevents us // seeming to be ready for a connection that will immediately fail when we @@ -451,23 +448,6 @@ func (c *Command) Run(args []string) (retCode int) { c.proxyCtx, c.proxyCancel = context.WithDeadline(c.Context, c.expiration) defer c.proxyCancel() - certPool := x509.NewCertPool() - certPool.AddCert(parsedCert) - - tlsConf := &tls.Config{ - Certificates: []tls.Certificate{ - { - Certificate: [][]byte{c.sessionAuthzData.Certificate}, - PrivateKey: ed25519.PrivateKey(c.sessionAuthzData.PrivateKey), - Leaf: parsedCert, - }, - }, - RootCAs: certPool, - ServerName: workerHost, - MinVersion: tls.VersionTLS13, - NextProtos: []string{"http/1.1", c.sessionAuthzData.SessionId}, - } - transport := cleanhttp.DefaultTransport() transport.DisableKeepAlives = false // This isn't/shouldn't used anyways really because the connection is diff --git a/internal/servers/worker/listeners.go b/internal/servers/worker/listeners.go index 12d9692ad2..ce259a7009 100644 --- a/internal/servers/worker/listeners.go +++ b/internal/servers/worker/listeners.go @@ -113,7 +113,7 @@ func (w *Worker) stopHttpServersAndListeners() error { err := ln.ProxyListener.Close() err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { - multierror.Append(closeErrors, err) + closeErrors = multierror.Append(closeErrors, err) } } @@ -133,7 +133,7 @@ func (w *Worker) stopAnyListeners() error { err := ln.ProxyListener.Close() err = listenerCloseErrorCheck(ln.Config.Type, err) if err != nil { - multierror.Append(closeErrors, err) + closeErrors = multierror.Append(closeErrors, err) } } diff --git a/internal/servers/worker/worker.go b/internal/servers/worker/worker.go index 64c428edcf..beca2f74cd 100644 --- a/internal/servers/worker/worker.go +++ b/internal/servers/worker/worker.go @@ -63,6 +63,10 @@ type Worker struct { // request. It can be set via startup in New below, or (eventually) via // SIGHUP. updateTags ua.Bool + + // Test-specific options + TestOverrideX509VerifyDnsName string + TestOverrideX509VerifyCertPool *x509.CertPool } func New(conf *Config) (*Worker, error) { @@ -324,10 +328,40 @@ func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) }, }, NextProtos: []string{"http/1.1"}, - ServerName: sessionId, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: certPool, MinVersion: tls.VersionTLS13, + + // These two are set this way so we can make use of VerifyConnection, + // which we set on this TLS config below. We are not skipping + // verification! + ClientAuth: tls.RequireAnyClientCert, + InsecureSkipVerify: true, + } + + // We disable normal DNS SAN behavior as we don't rely on DNS or IP + // addresses for security and want to avoid issues with including localhost + // etc. + verifyOpts := x509.VerifyOptions{ + DNSName: sessionId, + Roots: certPool, + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + } + if w.TestOverrideX509VerifyCertPool != nil { + verifyOpts.Roots = w.TestOverrideX509VerifyCertPool + } + if w.TestOverrideX509VerifyDnsName != "" { + verifyOpts.DNSName = w.TestOverrideX509VerifyDnsName + } + tlsConf.VerifyConnection = func(cs tls.ConnectionState) error { + // Go will not run this without at least one peer certificate, but + // doesn't hurt to check + if len(cs.PeerCertificates) == 0 { + return errors.New("no peer certificates provided") + } + _, err := cs.PeerCertificates[0].Verify(verifyOpts) + return err } si := &session.Info{ diff --git a/internal/tests/cluster/x509_verification_test.go b/internal/tests/cluster/x509_verification_test.go new file mode 100644 index 0000000000..bb346dc6db --- /dev/null +++ b/internal/tests/cluster/x509_verification_test.go @@ -0,0 +1,243 @@ +package cluster + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net/http" + "sync" + "testing" + "time" + + "github.com/hashicorp/boundary/api/targets" + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/cmd/config" + "github.com/hashicorp/boundary/internal/observability/event" + "github.com/hashicorp/boundary/internal/servers/controller" + "github.com/hashicorp/boundary/internal/servers/worker" + "github.com/hashicorp/boundary/internal/tests/helper" + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" + "nhooyr.io/websocket" +) + +func TestCustomX509Verification_Client(t *testing.T) { + req := require.New(t) + ctx := context.Background() + ec := event.TestEventerConfig(t, "TestWorkerReplay", event.TestWithObservationSink(t), event.TestWithSysSink(t)) + testLock := &sync.Mutex{} + logger := hclog.New(&hclog.LoggerOptions{ + Mutex: testLock, + Name: "test", + }) + req.NoError(event.InitSysEventer(logger, testLock, "use-TestCustomX509Verification", event.WithEventerConfig(&ec.EventerConfig))) + + conf, err := config.DevController() + conf.Eventing = &ec.EventerConfig + req.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + InitialResourcesSuffix: "1234567890", + Logger: logger.Named("c1"), + }) + t.Cleanup(c1.Shutdown) + + conf, err = config.DevWorker() + conf.Eventing = &ec.EventerConfig + req.NoError(err) + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: c1.ClusterAddrs(), + Logger: logger.Named("w1"), + }) + t.Cleanup(w1.Shutdown) + + // Give time for it to connect + time.Sleep(10 * time.Second) + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + req.NoError(err) + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + req.NoError(err) + req.NotNil(tgt) + + // Create test server, update default port on target + ts := helper.NewTestTcpServer(t) + require.NotNil(t, ts) + t.Cleanup(ts.Close) + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + req.NoError(err) + req.NotNil(tgt) + + tests := []struct { + name string + certPool *x509.CertPool + dnsName *bytes.Buffer + wantErrContains string + }{ + { + name: "base", + }, + { + name: "modified cert pool", + certPool: x509.NewCertPool(), + wantErrContains: "signed by unknown authority", + }, + { + name: "modified dns name", + dnsName: bytes.NewBuffer([]byte("foobar")), + wantErrContains: "not foobar", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + sess := helper.NewTestSession(ctx, t, tcl, "ttcp_1234567890") + + certPool := tc.certPool + if certPool == nil { + parsedCert, err := x509.ParseCertificate(sess.SessionAuthzData.Certificate) + require.NoError(err) + certPool = x509.NewCertPool() + certPool.AddCert(parsedCert) + } + + dnsName := sess.SessionId + if tc.dnsName != nil { + dnsName = tc.dnsName.String() + } + + verifyOpts := x509.VerifyOptions{ + DNSName: dnsName, + Roots: certPool, + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + } + sess.Transport.TLSClientConfig.VerifyConnection = func(cs tls.ConnectionState) error { + // Go will not run this without at least one peer certificate, but + // doesn't hurt to check + if len(cs.PeerCertificates) == 0 { + return errors.New("no peer certificates provided") + } + _, err := cs.PeerCertificates[0].Verify(verifyOpts) + return err + } + conn, _, err := websocket.Dial( + ctx, + fmt.Sprintf("wss://%s/v1/proxy", sess.WorkerAddr), + &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: sess.Transport, + }, + Subprotocols: []string{globals.TcpProxyV1}, + }, + ) + if conn != nil { + defer conn.Close(websocket.StatusNormalClosure, "done") + } + switch tc.wantErrContains != "" { + case true: + require.Error(err) + require.Contains(err.Error(), tc.wantErrContains) + default: + require.NoError(err) + } + }) + } +} + +func TestCustomX509Verification_Server(t *testing.T) { + ec := event.TestEventerConfig(t, "TestCustomX509Verification_Server", event.TestWithObservationSink(t), event.TestWithSysSink(t)) + testLock := &sync.Mutex{} + logger := hclog.New(&hclog.LoggerOptions{ + Mutex: testLock, + Name: "test", + }) + require.NoError(t, event.InitSysEventer(logger, testLock, "use-TestCustomX509Verification_Server", event.WithEventerConfig(&ec.EventerConfig))) + + t.Run("bad cert pool", testCustomX509Verification_Server(ec, x509.NewCertPool(), "", "bad certificate")) + t.Run("bad dns name", testCustomX509Verification_Server(ec, nil, "foobar", "bad certificate")) +} + +func testCustomX509Verification_Server(ec event.TestConfig, certPool *x509.CertPool, dnsName, wantErrContains string) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + req := require.New(t) + ctx := context.Background() + + conf, err := config.DevController() + conf.Eventing = &ec.EventerConfig + req.NoError(err) + c1 := controller.NewTestController(t, &controller.TestControllerOpts{ + Config: conf, + InitialResourcesSuffix: "1234567890", + }) + t.Cleanup(c1.Shutdown) + + conf, err = config.DevWorker() + conf.Eventing = &ec.EventerConfig + req.NoError(err) + w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{ + WorkerAuthKms: c1.Config().WorkerAuthKms, + InitialControllers: c1.ClusterAddrs(), + }) + w1.Worker().TestOverrideX509VerifyCertPool = certPool + w1.Worker().TestOverrideX509VerifyDnsName = dnsName + t.Cleanup(w1.Shutdown) + + // Give time for it to connect + time.Sleep(10 * time.Second) + + err = w1.Worker().WaitForNextSuccessfulStatusUpdate() + req.NoError(err) + + // Connect target + client := c1.Client() + client.SetToken(c1.Token().Token) + tcl := targets.NewClient(client) + tgt, err := tcl.Read(ctx, "ttcp_1234567890") + req.NoError(err) + req.NotNil(tgt) + + // Create test server, update default port on target + ts := helper.NewTestTcpServer(t) + require.NotNil(t, ts) + t.Cleanup(ts.Close) + tgt, err = tcl.Update(ctx, tgt.Item.Id, tgt.Item.Version, targets.WithTcpTargetDefaultPort(ts.Port()), targets.WithSessionConnectionLimit(-1)) + req.NoError(err) + req.NotNil(tgt) + + sess := helper.NewTestSession(ctx, t, tcl, "ttcp_1234567890") + + conn, _, err := websocket.Dial( + ctx, + fmt.Sprintf("wss://%s/v1/proxy", sess.WorkerAddr), + &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: sess.Transport, + }, + Subprotocols: []string{globals.TcpProxyV1}, + }, + ) + if conn != nil { + defer conn.Close(websocket.StatusNormalClosure, "done") + } + switch wantErrContains != "" { + case true: + req.Error(err) + req.Contains(err.Error(), wantErrContains) + default: + req.NoError(err) + } + } +} diff --git a/internal/tests/helper/testing_helper.go b/internal/tests/helper/testing_helper.go index 9ee2cad148..fdde7f160c 100644 --- a/internal/tests/helper/testing_helper.go +++ b/internal/tests/helper/testing_helper.go @@ -2,9 +2,6 @@ package helper import ( "context" - "crypto/ed25519" - "crypto/tls" - "crypto/x509" "encoding/binary" "errors" "fmt" @@ -18,6 +15,7 @@ import ( "github.com/hashicorp/boundary/api/targets" "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/cmd/commands/connect" "github.com/hashicorp/boundary/internal/proxy" "github.com/hashicorp/boundary/internal/servers/controller/common" "github.com/hashicorp/boundary/internal/servers/worker" @@ -54,11 +52,12 @@ const ( // TestSession represents an authorized session. type TestSession struct { - sessionId string - workerAddr string - transport *http.Transport - tofuToken string - connectionsLeft int32 + SessionId string + WorkerAddr string + Transport *http.Transport + tofuToken string + connectionsLeft int32 + SessionAuthzData *targetspb.SessionAuthorizationData } // NewTestSession authorizes a session and creates all of the data @@ -76,43 +75,27 @@ func NewTestSession( require.NotNil(sar) s := &TestSession{ - sessionId: sar.Item.SessionId, + SessionId: sar.Item.SessionId, } authzString := sar.GetItem().(*targets.SessionAuthorization).AuthorizationToken marshaled, err := base58.FastBase58Decoding(authzString) require.NoError(err) require.NotZero(marshaled) - sessionAuthzData := new(targetspb.SessionAuthorizationData) - err = proto.Unmarshal(marshaled, sessionAuthzData) + s.SessionAuthzData = new(targetspb.SessionAuthorizationData) + err = proto.Unmarshal(marshaled, s.SessionAuthzData) require.NoError(err) - require.NotZero(sessionAuthzData.GetWorkerInfo()) + require.NotZero(s.SessionAuthzData.GetWorkerInfo()) - s.workerAddr = sessionAuthzData.GetWorkerInfo()[0].GetAddress() + s.WorkerAddr = s.SessionAuthzData.GetWorkerInfo()[0].GetAddress() - parsedCert, err := x509.ParseCertificate(sessionAuthzData.Certificate) + tlsConf, err := connect.ClientTlsConfig(s.SessionAuthzData, "") require.NoError(err) - require.Len(parsedCert.DNSNames, 1) - - certPool := x509.NewCertPool() - certPool.AddCert(parsedCert) - tlsConf := &tls.Config{ - Certificates: []tls.Certificate{ - { - Certificate: [][]byte{sessionAuthzData.Certificate}, - PrivateKey: ed25519.PrivateKey(sessionAuthzData.PrivateKey), - Leaf: parsedCert, - }, - }, - RootCAs: certPool, - ServerName: parsedCert.DNSNames[0], - MinVersion: tls.VersionTLS13, - } - s.transport = cleanhttp.DefaultTransport() - s.transport.DisableKeepAlives = false - s.transport.TLSClientConfig = tlsConf - s.transport.IdleConnTimeout = 0 + s.Transport = cleanhttp.DefaultTransport() + s.Transport.DisableKeepAlives = false + s.Transport.TLSClientConfig = tlsConf + s.Transport.IdleConnTimeout = 0 return s } @@ -126,10 +109,10 @@ func (s *TestSession) connect(ctx context.Context, t *testing.T) net.Conn { require := require.New(t) conn, resp, err := websocket.Dial( ctx, - fmt.Sprintf("wss://%s/v1/proxy", s.workerAddr), + fmt.Sprintf("wss://%s/v1/proxy", s.WorkerAddr), &websocket.DialOptions{ HTTPClient: &http.Client{ - Transport: s.transport, + Transport: s.Transport, }, Subprotocols: []string{globals.TcpProxyV1}, }, @@ -184,7 +167,7 @@ func (s *TestSession) ExpectConnectionStateOnController( connectionRepo, err := connectionRepoFn() require.NoError(err) - conns, err := connectionRepo.ListConnectionsBySessionId(ctx, s.sessionId) + conns, err := connectionRepo.ListConnectionsBySessionId(ctx, s.SessionId) require.NoError(err) // To avoid misleading passing tests, we require this test be used // with sessions with connections.. @@ -290,7 +273,7 @@ func (s *TestSession) ExpectConnectionStateOnWorker( func (s *TestSession) testWorkerConnectionInfo(t *testing.T, tw *worker.TestWorker) map[string]worker.TestConnectionInfo { t.Helper() require := require.New(t) - si, ok := tw.LookupSession(s.sessionId) + si, ok := tw.LookupSession(s.SessionId) // This is always an error if the session has been removed from the // local state. require.True(ok)