feat(ssh): support ssh known hosts file (#6263)

pull/6267/head
Louis Ruch 6 months ago committed by GitHub
parent 6533c8bdb3
commit 4bae4f23f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -113,6 +113,7 @@ type Command struct {
flagWorkerAuthCaCertificateLifetime time.Duration
flagWorkerAuthDebuggingEnabled bool
flagWorkerRecordingStorageDir string
flagSshKnownHostsPath string
flagWorkerRecordingStorageMinimumAvailableCapacity string
flagBsrKey string
}
@ -427,6 +428,12 @@ func (c *Command) Flags() *base.FlagSets {
Usage: "Specifies the directory to store worker session recordings in dev mode. If not provided a temp directory will be created. Session recording is an Enterprise-only feature.",
})
f.StringVar(&base.StringVar{
Name: "worker-ssh-known-hosts-path",
Target: &c.flagSshKnownHostsPath,
Usage: "Specifies the path of the known_hosts file to be used by the worker for SSH host key verification of an SSH target in dev mode. SSH targets and SSH credential injection are Enterprise-only features.",
})
f.StringVar(&base.StringVar{
Name: "worker-recording-storage-minimum-available-capacity",
Target: &c.flagWorkerRecordingStorageMinimumAvailableCapacity,
@ -533,6 +540,7 @@ func (c *Command) Run(args []string) int {
c.Config.Plugins.ExecutionDir = c.flagPluginExecutionDir
if !c.flagControllerOnly {
c.Config.Worker.SshKnownHostsPath = c.flagSshKnownHostsPath
c.Config.Worker.AuthStoragePath = c.flagWorkerAuthStorageDir
c.Config.Worker.RecordingStoragePath = c.flagWorkerRecordingStorageDir
c.Config.Worker.RecordingStorageMinimumAvailableCapacity = c.flagWorkerRecordingStorageMinimumAvailableCapacity

@ -393,6 +393,11 @@ type Worker struct {
// they are sync'ed to the corresponding storage bucket. The path must already exist.
RecordingStoragePath string `hcl:"recording_storage_path"`
// SshKnownHostsPath represents the location of the known_hosts file to be used by the worker
// for SSH host key verification when connecting to ssh targets. The path must already exist.
// If not provided the worker will skip host key verification.
SshKnownHostsPath string `hcl:"ssh_known_hosts_path"`
// RecordingStorageMinimumAvailableCapacity represents the minimum amount of available
// disk space a worker needs in the path defined by RecordingStoragePath for processing
// sessions with recording enabled. The expected input value for this field is a

@ -280,7 +280,13 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function")
event.WriteError(ctx, op, err)
}
runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager, proxyHandlers.WithLogger(w.logger))
handlerOpts := []proxyHandlers.Option{proxyHandlers.WithLogger(w.logger)}
if cb := w.SshKnownHostsCallback.Load(); cb != nil {
handlerOpts = append(handlerOpts, proxyHandlers.WithSshHostKeyCallback(*cb))
}
runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager, handlerOpts...)
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying")

@ -8,6 +8,7 @@ import (
serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/go-hclog"
"golang.org/x/crypto/ssh"
)
// Option - how Options are passed as arguments.
@ -30,6 +31,7 @@ type Options struct {
WithTestKdcAddress string
WithTestKerberosServerHostname string
WithLogger hclog.Logger
WithSshHostKeyCallback ssh.HostKeyCallback
}
func getDefaultOptions() Options {
@ -87,3 +89,11 @@ func WithLogger(l hclog.Logger) Option {
o.WithLogger = l
}
}
// WithSshHostKeyCallback allows specifying a ssh.HostKeyCallback function
// to be used for host key verification.
func WithSshHostKeyCallback(with ssh.HostKeyCallback) Option {
return func(o *Options) {
o.WithSshHostKeyCallback = with
}
}

@ -4,6 +4,7 @@
package proxy
import (
"crypto/ed25519"
"net"
"reflect"
"runtime"
@ -11,6 +12,7 @@ import (
serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)
func Test_GetOpts(t *testing.T) {
@ -63,4 +65,15 @@ func Test_GetOpts(t *testing.T) {
testOpts.WithTestKerberosServerHostname = testKerberosServerHostname
assert.Equal(opts, testOpts)
})
t.Run("WithSshHostKeyCallback", func(t *testing.T) {
assert := assert.New(t)
opts := getDefaultOptions()
assert.Nil(opts.WithSshHostKeyCallback)
signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob")))
assert.NoError(err)
opts = GetOpts(WithSshHostKeyCallback(ssh.FixedHostKey(signer.PublicKey())))
assert.NotNil(opts.WithSshHostKeyCallback)
})
}

@ -49,6 +49,8 @@ import (
"github.com/mr-tron/base58"
"github.com/prometheus/client_golang/prometheus"
ua "go.uber.org/atomic"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/protobuf/proto"
@ -230,6 +232,10 @@ type Worker struct {
downstreamConnManager *cluster.DownstreamManager
HostServiceServer wpbs.HostServiceServer
// SshKnownHostsCallback is used to provide a ssh.HostKeyCallback for SSH host key verification
// when connecting to an SSH target. This is an atomic because it can be updated at runtime via SIGHUP.
SshKnownHostsCallback atomic.Pointer[ssh.HostKeyCallback]
}
func New(ctx context.Context, conf *Config) (*Worker, error) {
@ -286,6 +292,14 @@ func New(ctx context.Context, conf *Config) (*Worker, error) {
w.localStorageState.Store(server.NotConfiguredLocalStorageState)
}
if w.conf.RawConfig.Worker.SshKnownHostsPath != "" {
cb, err := knownhosts.New(w.conf.RawConfig.Worker.SshKnownHostsPath)
if err != nil {
return nil, fmt.Errorf("error loading ssh known hosts file: %w", err)
}
w.SshKnownHostsCallback.Store(&cb)
}
pluginLogger, err := event.NewHclogLogger(ctx, w.conf.Server.Eventer)
if err != nil {
return nil, fmt.Errorf("error creating plugin logger: %w", err)
@ -509,6 +523,19 @@ func (w *Worker) Reload(ctx context.Context, newConf *config.Config) {
default:
w.getDownstreamWorkersTimeoutDuration.Store(int64(newConf.Worker.GetDownstreamWorkersTimeoutDuration))
}
switch newConf.Worker.SshKnownHostsPath {
case "":
w.SshKnownHostsCallback.Store(nil)
default:
cb, err := knownhosts.New(newConf.Worker.SshKnownHostsPath)
if err != nil {
event.WriteError(w.baseContext, op, fmt.Errorf("error loading ssh known hosts file: %w", err))
break
}
w.SshKnownHostsCallback.Store(&cb)
}
// See comment about this in worker.go
session.CloseCallTimeout.Store(w.successfulRoutingInfoGracePeriod.Load())

@ -9,6 +9,9 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"sync"
"testing"
"time"
@ -33,11 +36,34 @@ import (
"github.com/hashicorp/nodeenrollment/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestWorkerNew(t *testing.T) {
knownHostsPath := t.TempDir() + "/known_hosts"
nonexistantKnownHostsPath := t.TempDir() + "/does_not_exist"
corruptedKnownHostsPath := t.TempDir() + "/corrupted_known_hosts"
file, err := os.Create(knownHostsPath)
require.NoError(t, err)
defer file.Close()
signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob")))
require.NoError(t, err)
line := fmt.Sprintf("::1 %s", string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
_, err = file.WriteString(line)
require.NoError(t, err)
corruptFile, err := os.Create(corruptedKnownHostsPath)
require.NoError(t, err)
defer corruptFile.Close()
_, err = corruptFile.WriteString("this is not valid known hosts content")
require.NoError(t, err)
tests := []struct {
name string
in *Config
@ -190,6 +216,87 @@ func TestWorkerNew(t *testing.T) {
assert.Equal(t, wpbs.UnimplementedHostServiceServer{}, w.HostServiceServer)
},
},
{
name: "valid with no known hosts path",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
assert.Nil(t, w.SshKnownHostsCallback.Load())
},
},
{
name: "valid known hosts path",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
Worker: &config.Worker{
SshKnownHostsPath: knownHostsPath,
},
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
assert.NotNil(t, w.SshKnownHostsCallback.Load())
},
},
{
name: "invalid known hosts path",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
Worker: &config.Worker{
SshKnownHostsPath: nonexistantKnownHostsPath,
},
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: true,
expErrMsg: "no such file or directory",
},
{
name: "corrupted known hosts file",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
Worker: &config.Worker{
SshKnownHostsPath: corruptedKnownHostsPath,
},
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: true,
expErrMsg: "illegal base64 data at input byte",
},
}
for _, tt := range tests {
@ -211,7 +318,7 @@ func TestWorkerNew(t *testing.T) {
w, err := New(context.Background(), tt.in)
if tt.expErr {
require.EqualError(t, err, tt.expErrMsg)
require.ErrorContains(t, err, tt.expErrMsg)
require.Nil(t, w)
return
}
@ -225,6 +332,19 @@ func TestWorkerNew(t *testing.T) {
}
func TestWorkerReload(t *testing.T) {
knownHostsPath := t.TempDir() + "/known_hosts"
file, err := os.Create(knownHostsPath)
require.NoError(t, err)
defer file.Close()
signer, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("foobfoobfoobfoobfoobfoobfoobfoob")))
require.NoError(t, err)
dummyAddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}
line := fmt.Sprintf("github.com %s", string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
_, err = file.WriteString(line)
require.NoError(t, err)
t.Run("default config is the same as the reload config", func(t *testing.T) {
require, assert := require.New(t), assert.New(t)
cfg := &Config{
@ -254,6 +374,7 @@ func TestWorkerReload(t *testing.T) {
assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load())
assert.Nil(w.SshKnownHostsCallback.Load())
w.Reload(context.Background(), cfg.RawConfig)
@ -266,6 +387,7 @@ func TestWorkerReload(t *testing.T) {
assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load())
assert.Nil(w.SshKnownHostsCallback.Load())
})
t.Run("new config is the same as the reload config", func(t *testing.T) {
@ -286,6 +408,7 @@ func TestWorkerReload(t *testing.T) {
SuccessfulControllerRPCGracePeriodDuration: 5 * time.Second,
ControllerRPCCallTimeoutDuration: 10 * time.Second,
GetDownstreamWorkersTimeoutDuration: 20 * time.Second,
SshKnownHostsPath: knownHostsPath,
},
},
}
@ -301,6 +424,10 @@ func TestWorkerReload(t *testing.T) {
assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
cb := w.SshKnownHostsCallback.Load()
require.NotNil(cb)
err = (*cb)("github.com:22", dummyAddr, signer.PublicKey())
assert.NoError(err)
w.Reload(context.Background(), cfg.RawConfig)
@ -313,6 +440,89 @@ func TestWorkerReload(t *testing.T) {
assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
cb = w.SshKnownHostsCallback.Load()
require.NotNil(cb)
err = (*cb)("github.com:22", dummyAddr, signer.PublicKey())
assert.NoError(err)
})
t.Run("new config is different", func(t *testing.T) {
require, assert := require.New(t), assert.New(t)
cfg := &Config{
Server: &base.Server{
Logger: hclog.Default(),
Eventer: &event.Eventer{},
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
},
RawConfig: &config.Config{
SharedConfig: &configutil.SharedConfig{DisableMlock: true},
Worker: &config.Worker{
SuccessfulControllerRPCGracePeriodDuration: 5 * time.Second,
ControllerRPCCallTimeoutDuration: 10 * time.Second,
GetDownstreamWorkersTimeoutDuration: 20 * time.Second,
SshKnownHostsPath: knownHostsPath,
},
},
}
w, err := New(context.Background(), cfg)
require.NoError(err)
assert.Equal(int64(5*time.Second), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(5*time.Second), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load())
assert.Equal(int64(10*time.Second), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
cb := w.SshKnownHostsCallback.Load()
require.NotNil(cb)
err = (*cb)("github.com:22", dummyAddr, signer.PublicKey())
assert.NoError(err)
// Update the config with new values
newKnownHostsFile := t.TempDir() + "/new_known_hosts"
newFile, err := os.Create(newKnownHostsFile)
require.NoError(err)
defer newFile.Close()
newSigner, err := ssh.NewSignerFromKey(ed25519.NewKeyFromSeed([]byte("noobnoobnoobnoobnoobnoobnoobnoob")))
require.NoError(err)
line := fmt.Sprintf("github.com %s", string(ssh.MarshalAuthorizedKey(newSigner.PublicKey())))
_, err = newFile.WriteString(line)
require.NoError(err)
cfg.RawConfig.Worker.SuccessfulControllerRPCGracePeriodDuration = 30 * time.Second
cfg.RawConfig.Worker.ControllerRPCCallTimeoutDuration = 35 * time.Second
cfg.RawConfig.Worker.GetDownstreamWorkersTimeoutDuration = 40 * time.Second
cfg.RawConfig.Worker.SshKnownHostsPath = newKnownHostsFile
w.Reload(context.Background(), cfg.RawConfig)
assert.Equal(int64(30*time.Second), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(30*time.Second), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load())
assert.Equal(int64(35*time.Second), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(35*time.Second), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(35*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(40*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
cb = w.SshKnownHostsCallback.Load()
require.NotNil(cb)
// Old signer should fail
err = (*cb)("github.com:22", dummyAddr, signer.PublicKey())
assert.Error(err)
// New signer should work
err = (*cb)("github.com:22", dummyAddr, newSigner.PublicKey())
assert.NoError(err)
})
}

Loading…
Cancel
Save