SSH Proxy support (#2736)

* initial commit
pull/2768/head
Irena Rindos 3 years ago committed by GitHub
parent 1fb960575f
commit beb97ac9f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -273,7 +273,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
}
// Verify the protocol has a supported proxy before calling RequestAuthorizeConnection
handleProxyFn, err := proxyHandlers.GetHandler(endpointUrl.Scheme)
handleProxyFn, err := proxyHandlers.GetHandler(workerId, ci.GetProtocolContext())
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to get proxy handler")
event.WriteError(ctx, op, err)
@ -306,7 +306,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
return
}
runProxy()
runProxy(ctx)
}, nil
}

@ -1,6 +1,8 @@
package proxy
import (
"net"
serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
)
@ -19,11 +21,13 @@ func GetOpts(opt ...Option) Options {
// Options = how options are represented
type Options struct {
WithInjectedApplicationCredentials []*serverpb.Credential
WithPostConnectionHook func(net.Conn)
}
func getDefaultOptions() Options {
return Options{
WithInjectedApplicationCredentials: nil,
WithPostConnectionHook: nil,
}
}
@ -34,3 +38,13 @@ func WithInjectedApplicationCredentials(creds []*serverpb.Credential) Option {
o.WithInjectedApplicationCredentials = creds
}
}
// WithPostConnectionHook provides a hook function to be called after a
// connection is established in a dialFunction. When a dialer accepts
// WithPostConnectionHook the passed in function should be called prior to any
// other blocking call.
func WithPostConnectionHook(fn func(net.Conn)) Option {
return func(o *Options) {
o.WithPostConnectionHook = fn
}
}

@ -1,6 +1,9 @@
package proxy
import (
"net"
"reflect"
"runtime"
"testing"
serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
@ -26,4 +29,17 @@ func Test_GetOpts(t *testing.T) {
testOpts.WithInjectedApplicationCredentials = []*serverpb.Credential{c}
assert.Equal(opts, testOpts)
})
t.Run("WithPostConnectionHook", func(t *testing.T) {
assert := assert.New(t)
testFn := func(net.Conn) {}
opts := GetOpts(WithPostConnectionHook(testFn))
testOpts := getDefaultOptions()
assert.NotEqual(opts, testOpts)
testOpts.WithPostConnectionHook = testFn
assert.Equal(
runtime.FuncForPC(reflect.ValueOf(opts.WithPostConnectionHook).Pointer()).Name(),
runtime.FuncForPC(reflect.ValueOf(testOpts.WithPostConnectionHook).Pointer()).Name(),
)
})
}

@ -6,10 +6,13 @@ import (
"net"
"sync"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)
var (
TcpHandlerName = "tcp"
// handlers is the map of registered handlers
handlers sync.Map
@ -18,11 +21,16 @@ var (
// ErrProtocolAlreadyRegistered specifies the provided protocol has already been registered
ErrProtocolAlreadyRegistered = errors.New("proxy: protocol already registered")
// GetHandler returns the handler registered for the provided worker and
// protocolContext. If a protocol cannot be determined or the protocol is
// not registered nil, ErrUnknownProtocol is returned.
GetHandler = tcpOnly
)
// ProxyConnFn is called after the call to ConnectConnection on the cluster.
// ProxyConnFn blocks until the specific request that is being proxied is finished
type ProxyConnFn func()
type ProxyConnFn func(ctx context.Context)
// Handler is the type that all proxies need to implement to be called by the worker
// when a new client connection is created. If there is an error ProxyConnFn must
@ -39,10 +47,9 @@ func RegisterHandler(protocol string, handler Handler) error {
return nil
}
// GetHandler returns the handler registered for the provided protocol. If the protocol
// is not registered nil and ErrUnknownProtocol is returned.
func GetHandler(protocol string) (Handler, error) {
handler, ok := handlers.Load(protocol)
// tcpOnly returns only the TCP protocol.
func tcpOnly(string, proto.Message) (Handler, error) {
handler, ok := handlers.Load(TcpHandlerName)
if !ok {
return nil, ErrUnknownProtocol
}

@ -35,9 +35,8 @@ func TestRegisterHandler(t *testing.T) {
require.NoError(err)
}
func TestGetHandler(t *testing.T) {
func TestAlwaysTcpGetHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
return nil, nil
}
@ -47,15 +46,13 @@ func TestGetHandler(t *testing.T) {
})
handlers = sync.Map{}
err := RegisterHandler("fn", fn)
require.NoError(err)
gotFn, err := GetHandler("fake")
_, err := tcpOnly("wid", nil)
require.Error(err)
assert.ErrorIs(err, ErrUnknownProtocol)
assert.Nil(gotFn)
gotFn, err = GetHandler("fn")
require.NoError(RegisterHandler("tcp", fn))
handler, err := tcpOnly("wid", nil)
require.NoError(err)
assert.NotNil(gotFn)
require.NotNil(handler)
}

@ -20,7 +20,7 @@ func directDialer(ctx context.Context, endpoint string, _ string, _ proto.Messag
if len(endpoint) == 0 {
return nil, errors.New(ctx, errors.InvalidParameter, op, "endpoint is empty")
}
d, err := NewProxyDialer(ctx, func() (net.Conn, error) {
d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) {
remoteConn, err := net.Dial("tcp", endpoint)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
@ -52,12 +52,12 @@ func (p *proxyAddr) Port() uint32 {
// ProxyDialer dials downstream to eventually get to the target host.
type ProxyDialer struct {
dialFn func() (net.Conn, error)
dialFn func(...Option) (net.Conn, error)
latestAddr atomic.Pointer[proxyAddr]
}
// Returns a new proxy dialer using the provided function to get the net.Conn.
func NewProxyDialer(ctx context.Context, df func() (net.Conn, error)) (*ProxyDialer, error) {
func NewProxyDialer(ctx context.Context, df func(...Option) (net.Conn, error)) (*ProxyDialer, error) {
const op = "proxy.NewProxyDialer"
if df == nil {
return nil, errors.New(ctx, errors.InvalidParameter, op, "dialing function is nil")
@ -86,9 +86,9 @@ type portAndIpGetter interface {
// Dial uses the provided dial function to get a net.Conn and record its
// net.Addr information. The returned net.Addr should contain the information
// for the endpoint that is being proxied to.
func (d *ProxyDialer) Dial(ctx context.Context) (net.Conn, error) {
func (d *ProxyDialer) Dial(ctx context.Context, opt ...Option) (net.Conn, error) {
const op = "proxy.(*ProxyDialer).Dial"
c, err := d.dialFn()
c, err := d.dialFn(opt...)
if err != nil {
return nil, err
}

@ -15,7 +15,7 @@ func TestNewProxyDialer(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, d)
d, err = NewProxyDialer(context.Background(), func() (net.Conn, error) {
d, err = NewProxyDialer(context.Background(), func(...Option) (net.Conn, error) {
c, _ := net.Pipe()
return c, nil
})
@ -34,9 +34,10 @@ func TestProxyDialer(t *testing.T) {
t.Run("Dial error", func(t *testing.T) {
expectedErr := errors.New("test error")
d, err := NewProxyDialer(ctx, func() (net.Conn, error) {
d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) {
return nil, expectedErr
})
require.NoError(t, err)
assert.Nil(t, d.LastConnectionAddr())
badC, err := d.Dial(ctx)
require.Error(t, err)
@ -45,9 +46,10 @@ func TestProxyDialer(t *testing.T) {
})
t.Run("Successful Dial", func(t *testing.T) {
d, err := NewProxyDialer(ctx, func() (net.Conn, error) {
d, err := NewProxyDialer(ctx, func(...Option) (net.Conn, error) {
return net.Dial("tcp", l.Addr().String())
})
require.NoError(t, err)
assert.Nil(t, d.LastConnectionAddr())
c, err := d.Dial(ctx)
require.NoError(t, err)

@ -12,20 +12,18 @@ import (
)
func init() {
err := proxy.RegisterHandler("tcp", handleProxy)
err := proxy.RegisterHandler(proxy.TcpHandlerName, handleProxy)
if err != nil {
panic(err)
}
}
// handleProxy creates a tcp proxy between the incoming websocket conn and the
// connection it creates with the remote endpoint. handleTcpProxyV1 sets the connectionId
// as connected in the repository.
// handleProxy creates a tcp proxy between the incoming conn and the
// connection created by the ProxyDialer.
//
// handleProxy blocks until an error (EOF on happy path) is received on either
// connection.
//
// All options are ignored.
// handleProxy returns a ProxyConnFn which starts the copy between the
// connections and blocks until an error (EOF on happy path) is received on
// either connection.
func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) {
const op = "tcp.HandleProxy"
switch {
@ -41,7 +39,7 @@ func handleProxy(ctx context.Context, conn net.Conn, out *proxy.ProxyDialer, con
return nil, err
}
return func() {
return func(ctx context.Context) {
connWg := new(sync.WaitGroup)
connWg.Add(2)
go func() {

@ -29,7 +29,7 @@ func TestHandleProxy_Errors(t *testing.T) {
t.Cleanup(func() {
l.Close()
})
dialer, err := proxy.NewProxyDialer(context.Background(), func() (net.Conn, error) {
dialer, err := proxy.NewProxyDialer(context.Background(), func(...proxy.Option) (net.Conn, error) {
return net.Dial("tcp", l.Addr().String())
})
require.NoError(t, err)
@ -156,7 +156,7 @@ func TestHandleTcpProxyV1(t *testing.T) {
resp, _, err := s.RequestAuthorizeConnection(ctx, "workerid", connCancelFn)
require.NoError(err)
tDial, err := proxy.NewProxyDialer(ctx, func() (net.Conn, error) {
tDial, err := proxy.NewProxyDialer(ctx, func(...proxy.Option) (net.Conn, error) {
return net.Dial("tcp", l.Addr().String())
})
require.NoError(err)
@ -170,7 +170,7 @@ func TestHandleTcpProxyV1(t *testing.T) {
// https://pkg.go.dev/testing#T.FailNow
require.NoError(err)
})
fn()
fn(context.Background())
}()
// wait for HandleTcpProxyV1 to dial endpoint

Loading…
Cancel
Save