From 4bae4f23f593bc55e4d1c58ed88fbd11731367d8 Mon Sep 17 00:00:00 2001 From: Louis Ruch Date: Wed, 19 Nov 2025 16:24:20 -0800 Subject: [PATCH] feat(ssh): support ssh known hosts file (#6263) --- internal/cmd/commands/dev/dev.go | 8 + internal/cmd/config/config.go | 5 + internal/daemon/worker/handler.go | 8 +- internal/daemon/worker/proxy/options.go | 10 + internal/daemon/worker/proxy/options_test.go | 13 ++ internal/daemon/worker/worker.go | 27 +++ internal/daemon/worker/worker_test.go | 212 ++++++++++++++++++- 7 files changed, 281 insertions(+), 2 deletions(-) diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 803fb5a03a..4df1c8dbfd 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -113,6 +113,7 @@ type Command struct { flagWorkerAuthCaCertificateLifetime time.Duration flagWorkerAuthDebuggingEnabled bool flagWorkerRecordingStorageDir string + flagSshKnownHostsPath string flagWorkerRecordingStorageMinimumAvailableCapacity string flagBsrKey string } @@ -427,6 +428,12 @@ func (c *Command) Flags() *base.FlagSets { Usage: "Specifies the directory to store worker session recordings in dev mode. If not provided a temp directory will be created. Session recording is an Enterprise-only feature.", }) + f.StringVar(&base.StringVar{ + Name: "worker-ssh-known-hosts-path", + Target: &c.flagSshKnownHostsPath, + Usage: "Specifies the path of the known_hosts file to be used by the worker for SSH host key verification of an SSH target in dev mode. SSH targets and SSH credential injection are Enterprise-only features.", + }) + f.StringVar(&base.StringVar{ Name: "worker-recording-storage-minimum-available-capacity", Target: &c.flagWorkerRecordingStorageMinimumAvailableCapacity, @@ -533,6 +540,7 @@ func (c *Command) Run(args []string) int { c.Config.Plugins.ExecutionDir = c.flagPluginExecutionDir if !c.flagControllerOnly { + c.Config.Worker.SshKnownHostsPath = c.flagSshKnownHostsPath c.Config.Worker.AuthStoragePath = c.flagWorkerAuthStorageDir c.Config.Worker.RecordingStoragePath = c.flagWorkerRecordingStorageDir c.Config.Worker.RecordingStorageMinimumAvailableCapacity = c.flagWorkerRecordingStorageMinimumAvailableCapacity diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index f2f39e384e..7935977246 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -393,6 +393,11 @@ type Worker struct { // they are sync'ed to the corresponding storage bucket. The path must already exist. RecordingStoragePath string `hcl:"recording_storage_path"` + // SshKnownHostsPath represents the location of the known_hosts file to be used by the worker + // for SSH host key verification when connecting to ssh targets. The path must already exist. + // If not provided the worker will skip host key verification. + SshKnownHostsPath string `hcl:"ssh_known_hosts_path"` + // RecordingStorageMinimumAvailableCapacity represents the minimum amount of available // disk space a worker needs in the path defined by RecordingStoragePath for processing // sessions with recording enabled. The expected input value for this field is a diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 96781d6aaa..9065e604d6 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -280,7 +280,13 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function") event.WriteError(ctx, op, err) } - runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager, proxyHandlers.WithLogger(w.logger)) + + handlerOpts := []proxyHandlers.Option{proxyHandlers.WithLogger(w.logger)} + if cb := w.SshKnownHostsCallback.Load(); cb != nil { + handlerOpts = append(handlerOpts, proxyHandlers.WithSshHostKeyCallback(*cb)) + } + + runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager, handlerOpts...) if err != nil { conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying") diff --git a/internal/daemon/worker/proxy/options.go b/internal/daemon/worker/proxy/options.go index f029fa19c1..7bdceecafb 100644 --- a/internal/daemon/worker/proxy/options.go +++ b/internal/daemon/worker/proxy/options.go @@ -8,6 +8,7 @@ import ( serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/hashicorp/go-hclog" + "golang.org/x/crypto/ssh" ) // Option - how Options are passed as arguments. @@ -30,6 +31,7 @@ type Options struct { WithTestKdcAddress string WithTestKerberosServerHostname string WithLogger hclog.Logger + WithSshHostKeyCallback ssh.HostKeyCallback } func getDefaultOptions() Options { @@ -87,3 +89,11 @@ func WithLogger(l hclog.Logger) Option { o.WithLogger = l } } + +// WithSshHostKeyCallback allows specifying a ssh.HostKeyCallback function +// to be used for host key verification. +func WithSshHostKeyCallback(with ssh.HostKeyCallback) Option { + return func(o *Options) { + o.WithSshHostKeyCallback = with + } +} diff --git a/internal/daemon/worker/proxy/options_test.go b/internal/daemon/worker/proxy/options_test.go index b547b08c2a..1b20f9e5bd 100644 --- a/internal/daemon/worker/proxy/options_test.go +++ b/internal/daemon/worker/proxy/options_test.go @@ -4,6 +4,7 @@ package proxy import ( + "crypto/ed25519" "net" "reflect" "runtime" @@ -11,6 +12,7 @@ import ( serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" ) func Test_GetOpts(t *testing.T) { @@ -63,4 +65,15 @@ func Test_GetOpts(t *testing.T) { testOpts.WithTestKerberosServerHostname = testKerberosServerHostname assert.Equal(opts, testOpts) }) + t.Run("WithSshHostKeyCallback", func(t *testing.T) { + assert := assert.New(t) + opts := getDefaultOptions() + assert.Nil(opts.WithSshHostKeyCallback) + + signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob"))) + assert.NoError(err) + + opts = GetOpts(WithSshHostKeyCallback(ssh.FixedHostKey(signer.PublicKey()))) + assert.NotNil(opts.WithSshHostKeyCallback) + }) } diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index f792c24234..65a122025e 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -49,6 +49,8 @@ import ( "github.com/mr-tron/base58" "github.com/prometheus/client_golang/prometheus" ua "go.uber.org/atomic" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" "google.golang.org/grpc" "google.golang.org/grpc/resolver/manual" "google.golang.org/protobuf/proto" @@ -230,6 +232,10 @@ type Worker struct { downstreamConnManager *cluster.DownstreamManager HostServiceServer wpbs.HostServiceServer + + // SshKnownHostsCallback is used to provide a ssh.HostKeyCallback for SSH host key verification + // when connecting to an SSH target. This is an atomic because it can be updated at runtime via SIGHUP. + SshKnownHostsCallback atomic.Pointer[ssh.HostKeyCallback] } func New(ctx context.Context, conf *Config) (*Worker, error) { @@ -286,6 +292,14 @@ func New(ctx context.Context, conf *Config) (*Worker, error) { w.localStorageState.Store(server.NotConfiguredLocalStorageState) } + if w.conf.RawConfig.Worker.SshKnownHostsPath != "" { + cb, err := knownhosts.New(w.conf.RawConfig.Worker.SshKnownHostsPath) + if err != nil { + return nil, fmt.Errorf("error loading ssh known hosts file: %w", err) + } + w.SshKnownHostsCallback.Store(&cb) + } + pluginLogger, err := event.NewHclogLogger(ctx, w.conf.Server.Eventer) if err != nil { return nil, fmt.Errorf("error creating plugin logger: %w", err) @@ -509,6 +523,19 @@ func (w *Worker) Reload(ctx context.Context, newConf *config.Config) { default: w.getDownstreamWorkersTimeoutDuration.Store(int64(newConf.Worker.GetDownstreamWorkersTimeoutDuration)) } + + switch newConf.Worker.SshKnownHostsPath { + case "": + w.SshKnownHostsCallback.Store(nil) + default: + cb, err := knownhosts.New(newConf.Worker.SshKnownHostsPath) + if err != nil { + event.WriteError(w.baseContext, op, fmt.Errorf("error loading ssh known hosts file: %w", err)) + break + } + w.SshKnownHostsCallback.Store(&cb) + } + // See comment about this in worker.go session.CloseCallTimeout.Store(w.successfulRoutingInfoGracePeriod.Load()) diff --git a/internal/daemon/worker/worker_test.go b/internal/daemon/worker/worker_test.go index 0e8cc6edcc..c8ea966bca 100644 --- a/internal/daemon/worker/worker_test.go +++ b/internal/daemon/worker/worker_test.go @@ -9,6 +9,9 @@ import ( "crypto/rand" "crypto/tls" "crypto/x509" + "fmt" + "net" + "os" "sync" "testing" "time" @@ -33,11 +36,34 @@ import ( "github.com/hashicorp/nodeenrollment/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) func TestWorkerNew(t *testing.T) { + knownHostsPath := t.TempDir() + "/known_hosts" + nonexistantKnownHostsPath := t.TempDir() + "/does_not_exist" + corruptedKnownHostsPath := t.TempDir() + "/corrupted_known_hosts" + + file, err := os.Create(knownHostsPath) + require.NoError(t, err) + defer file.Close() + + signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob"))) + require.NoError(t, err) + line := fmt.Sprintf("::1 %s", string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + _, err = file.WriteString(line) + + require.NoError(t, err) + + corruptFile, err := os.Create(corruptedKnownHostsPath) + require.NoError(t, err) + defer corruptFile.Close() + + _, err = corruptFile.WriteString("this is not valid known hosts content") + require.NoError(t, err) + tests := []struct { name string in *Config @@ -190,6 +216,87 @@ func TestWorkerNew(t *testing.T) { assert.Equal(t, wpbs.UnimplementedHostServiceServer{}, w.HostServiceServer) }, }, + { + name: "valid with no known hosts path", + in: &Config{ + Server: &base.Server{ + Listeners: []*base.ServerListener{ + {Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}}, + }, + }, + RawConfig: &config.Config{ + SharedConfig: &configutil.SharedConfig{ + DisableMlock: true, + }, + }, + }, + expErr: false, + assertions: func(t *testing.T, w *Worker) { + assert.Nil(t, w.SshKnownHostsCallback.Load()) + }, + }, + { + name: "valid known hosts path", + in: &Config{ + Server: &base.Server{ + Listeners: []*base.ServerListener{ + {Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}}, + }, + }, + RawConfig: &config.Config{ + Worker: &config.Worker{ + SshKnownHostsPath: knownHostsPath, + }, + SharedConfig: &configutil.SharedConfig{ + DisableMlock: true, + }, + }, + }, + expErr: false, + assertions: func(t *testing.T, w *Worker) { + assert.NotNil(t, w.SshKnownHostsCallback.Load()) + }, + }, + { + name: "invalid known hosts path", + in: &Config{ + Server: &base.Server{ + Listeners: []*base.ServerListener{ + {Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}}, + }, + }, + RawConfig: &config.Config{ + Worker: &config.Worker{ + SshKnownHostsPath: nonexistantKnownHostsPath, + }, + SharedConfig: &configutil.SharedConfig{ + DisableMlock: true, + }, + }, + }, + expErr: true, + expErrMsg: "no such file or directory", + }, + { + name: "corrupted known hosts file", + in: &Config{ + Server: &base.Server{ + Listeners: []*base.ServerListener{ + {Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}}, + }, + }, + RawConfig: &config.Config{ + Worker: &config.Worker{ + SshKnownHostsPath: corruptedKnownHostsPath, + }, + SharedConfig: &configutil.SharedConfig{ + DisableMlock: true, + }, + }, + }, + expErr: true, + expErrMsg: "illegal base64 data at input byte", + }, } for _, tt := range tests { @@ -211,7 +318,7 @@ func TestWorkerNew(t *testing.T) { w, err := New(context.Background(), tt.in) if tt.expErr { - require.EqualError(t, err, tt.expErrMsg) + require.ErrorContains(t, err, tt.expErrMsg) require.Nil(t, w) return } @@ -225,6 +332,19 @@ func TestWorkerNew(t *testing.T) { } func TestWorkerReload(t *testing.T) { + knownHostsPath := t.TempDir() + "/known_hosts" + file, err := os.Create(knownHostsPath) + require.NoError(t, err) + defer file.Close() + + signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob"))) + require.NoError(t, err) + + dummyAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22} + line := fmt.Sprintf("github.com %s", string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + _, err = file.WriteString(line) + require.NoError(t, err) + t.Run("default config is the same as the reload config", func(t *testing.T) { require, assert := require.New(t), assert.New(t) cfg := &Config{ @@ -254,6 +374,7 @@ func TestWorkerReload(t *testing.T) { assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load()) assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load()) + assert.Nil(w.SshKnownHostsCallback.Load()) w.Reload(context.Background(), cfg.RawConfig) @@ -266,6 +387,7 @@ func TestWorkerReload(t *testing.T) { assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load()) assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load()) + assert.Nil(w.SshKnownHostsCallback.Load()) }) t.Run("new config is the same as the reload config", func(t *testing.T) { @@ -286,6 +408,7 @@ func TestWorkerReload(t *testing.T) { SuccessfulControllerRPCGracePeriodDuration: 5 * time.Second, ControllerRPCCallTimeoutDuration: 10 * time.Second, GetDownstreamWorkersTimeoutDuration: 20 * time.Second, + SshKnownHostsPath: knownHostsPath, }, }, } @@ -301,6 +424,10 @@ func TestWorkerReload(t *testing.T) { assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load()) assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load()) + cb := w.SshKnownHostsCallback.Load() + require.NotNil(cb) + err = (*cb)("github.com:22", dummyAddr, signer.PublicKey()) + assert.NoError(err) w.Reload(context.Background(), cfg.RawConfig) @@ -313,6 +440,89 @@ func TestWorkerReload(t *testing.T) { assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load()) assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load()) + cb = w.SshKnownHostsCallback.Load() + require.NotNil(cb) + err = (*cb)("github.com:22", dummyAddr, signer.PublicKey()) + assert.NoError(err) + }) + + t.Run("new config is different", func(t *testing.T) { + require, assert := require.New(t), assert.New(t) + cfg := &Config{ + Server: &base.Server{ + Logger: hclog.Default(), + Eventer: &event.Eventer{}, + Listeners: []*base.ServerListener{ + {Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}}, + {Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}}, + {Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}}, + }, + }, + RawConfig: &config.Config{ + SharedConfig: &configutil.SharedConfig{DisableMlock: true}, + Worker: &config.Worker{ + SuccessfulControllerRPCGracePeriodDuration: 5 * time.Second, + ControllerRPCCallTimeoutDuration: 10 * time.Second, + GetDownstreamWorkersTimeoutDuration: 20 * time.Second, + SshKnownHostsPath: knownHostsPath, + }, + }, + } + w, err := New(context.Background(), cfg) + require.NoError(err) + + assert.Equal(int64(5*time.Second), w.successfulRoutingInfoGracePeriod.Load()) + assert.Equal(int64(5*time.Second), w.successfulSessionInfoGracePeriod.Load()) + assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load()) + + assert.Equal(int64(10*time.Second), w.routingInfoCallTimeoutDuration.Load()) + assert.Equal(int64(10*time.Second), w.statisticsCallTimeoutDuration.Load()) + assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load()) + + assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load()) + cb := w.SshKnownHostsCallback.Load() + require.NotNil(cb) + err = (*cb)("github.com:22", dummyAddr, signer.PublicKey()) + assert.NoError(err) + + // Update the config with new values + newKnownHostsFile := t.TempDir() + "/new_known_hosts" + newFile, err := os.Create(newKnownHostsFile) + require.NoError(err) + defer newFile.Close() + + newSigner, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("noobnoobnoobnoobnoobnoobnoobnoob"))) + require.NoError(err) + + line := fmt.Sprintf("github.com %s", string(ssh.MarshalAuthorizedKey(newSigner.PublicKey()))) + _, err = newFile.WriteString(line) + require.NoError(err) + + cfg.RawConfig.Worker.SuccessfulControllerRPCGracePeriodDuration = 30 * time.Second + cfg.RawConfig.Worker.ControllerRPCCallTimeoutDuration = 35 * time.Second + cfg.RawConfig.Worker.GetDownstreamWorkersTimeoutDuration = 40 * time.Second + cfg.RawConfig.Worker.SshKnownHostsPath = newKnownHostsFile + + w.Reload(context.Background(), cfg.RawConfig) + + assert.Equal(int64(30*time.Second), w.successfulRoutingInfoGracePeriod.Load()) + assert.Equal(int64(30*time.Second), w.successfulSessionInfoGracePeriod.Load()) + assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load()) + + assert.Equal(int64(35*time.Second), w.routingInfoCallTimeoutDuration.Load()) + assert.Equal(int64(35*time.Second), w.statisticsCallTimeoutDuration.Load()) + assert.Equal(int64(35*time.Second), w.sessionInfoCallTimeoutDuration.Load()) + + assert.Equal(int64(40*time.Second), w.getDownstreamWorkersTimeoutDuration.Load()) + cb = w.SshKnownHostsCallback.Load() + require.NotNil(cb) + + // Old signer should fail + err = (*cb)("github.com:22", dummyAddr, signer.PublicKey()) + assert.Error(err) + // New signer should work + err = (*cb)("github.com:22", dummyAddr, newSigner.PublicKey()) + assert.NoError(err) }) }