feat(clientproxy): automatically close boundary connect when proxy is no longer in use (#6232)

Co-authored-by: Andrew Gaffney <andrew@gaffney.cc>

@ -39,6 +39,7 @@ type Options struct {
WithSkipSessionTeardown bool
withSessionTeardownTimeout time.Duration
withApiClient *api.Client
withInactivityTimeout time.Duration
}
// Option is a function that takes in an options struct and sets values or
@ -142,3 +143,12 @@ func WithApiClient(with *api.Client) Option {
return nil
}
}
// WithInactivityTimeout provides an optional duration after which a session
// with no active connections will be cancelled
func WithInactivityTimeout(with time.Duration) Option {
return func(o *Options) error {
o.withInactivityTimeout = with
return nil
}
}
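
For reference, a minimal sketch of how a caller would pass the new option when starting the proxy (the 30-second value and the surrounding error handling are illustrative only; apiproxy is the package alias used by the connect command later in this diff):

pxy, err := apiproxy.New(ctx, sessAuthz.AuthorizationToken)
if err != nil {
	return err
}
// Cancel the session after 30 seconds with no active connections.
if err := pxy.Start(apiproxy.WithInactivityTimeout(30 * time.Second)); err != nil {
	return err
}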

@ -34,26 +34,28 @@ import (
const sessionCancelTimeout = 30 * time.Second
type ClientProxy struct {
tofuToken string
cachedListenerAddress *ua.String
connectionsLeft *atomic.Int32
connsLeftCh chan int32
callerConnectionsLeftCh chan int32
apiClient *api.Client
sessionAuthzData *targets.SessionAuthorizationData
createTime time.Time
expiration time.Time
ctx context.Context
cancel context.CancelFunc
transport *http.Transport
workerAddr string
listenAddrPort netip.AddrPort
listener *atomic.Value
listenerCloseOnce *sync.Once
clientTlsConf *tls.Config
connWg *sync.WaitGroup
started *atomic.Bool
skipSessionTeardown bool
tofuToken string
cachedListenerAddress *ua.String
connectionsLeft *atomic.Int32
activeConns *atomic.Int32
connsLeftCh chan int32
callerConnsLeftCh chan int32
apiClient *api.Client
sessionAuthzData *targets.SessionAuthorizationData
createTime time.Time
expiration time.Time
ctx context.Context
cancel context.CancelFunc
transport *http.Transport
workerAddr string
listenAddrPort netip.AddrPort
listener *atomic.Value
listenerCloseOnce *sync.Once
clientTlsConf *tls.Config
connWg *sync.WaitGroup
started *atomic.Bool
skipSessionTeardown bool
closeReason *atomic.Value
}
// New creates a new client proxy. The given context should be cancelable; once
@ -90,17 +92,19 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
}
p := &ClientProxy{
cachedListenerAddress: ua.NewString(""),
connsLeftCh: make(chan int32),
connectionsLeft: new(atomic.Int32),
listener: new(atomic.Value),
listenerCloseOnce: new(sync.Once),
connWg: new(sync.WaitGroup),
listenAddrPort: opts.WithListenAddrPort,
callerConnectionsLeftCh: opts.WithConnectionsLeftCh,
started: new(atomic.Bool),
skipSessionTeardown: opts.WithSkipSessionTeardown,
apiClient: opts.withApiClient,
cachedListenerAddress: ua.NewString(""),
connsLeftCh: make(chan int32),
connectionsLeft: new(atomic.Int32),
activeConns: new(atomic.Int32),
listener: new(atomic.Value),
listenerCloseOnce: new(sync.Once),
connWg: new(sync.WaitGroup),
listenAddrPort: opts.WithListenAddrPort,
callerConnsLeftCh: opts.WithConnectionsLeftCh,
started: new(atomic.Bool),
skipSessionTeardown: opts.WithSkipSessionTeardown,
apiClient: opts.withApiClient,
closeReason: new(atomic.Value),
}
if opts.WithListener != nil {
@ -142,7 +146,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
// We don't _rely_ on client-side timeout verification but this prevents us
// seeming to be ready for a connection that will immediately fail when we
// try to actually make it
p.ctx, p.cancel = context.WithDeadline(ctx, p.expiration)
p.ctx, p.cancel = context.WithDeadlineCause(ctx, p.expiration, fmt.Errorf("Session has expired"))
transport := cleanhttp.DefaultTransport()
transport.DisableKeepAlives = false
@ -212,6 +216,17 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
// Ensure closing the listener runs on any other return condition
defer listenerCloseFunc()
// automatically close the proxy when inactive
proxyAutoClose := time.AfterFunc(10*time.Minute, func() {
p.cancel()
p.setCloseReason("Inactivity timeout reached")
})
activeConnCh := make(chan int32)
activeConnFn := func(d int32) {
activeConnCh <- p.activeConns.Add(d)
}
fin := make(chan error, 10)
p.connWg.Add(1)
go func() {
@ -243,8 +258,10 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
return
}
}
activeConnFn(1)
p.connWg.Add(1)
go func() {
defer activeConnFn(-1)
defer listeningConn.Close()
defer p.connWg.Done()
wsConn, err := p.getWsConn(p.ctx)
@ -305,27 +322,40 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
}()
defer p.connWg.Done()
defer listenerCloseFunc()
for {
select {
case <-p.ctx.Done():
if err := context.Cause(p.ctx); !errors.Is(err, context.Canceled) {
p.setCloseReason(err.Error())
}
return
case connsLeft := <-p.connsLeftCh:
p.connectionsLeft.Store(connsLeft)
if p.callerConnectionsLeftCh != nil {
p.callerConnectionsLeftCh <- connsLeft
if p.callerConnsLeftCh != nil {
p.callerConnsLeftCh <- connsLeft
}
if connsLeft == 0 {
// Close the listener as we can't authorize any more
// connections
p.setCloseReason("No connections left in session")
return
}
case activeConns := <-activeConnCh:
switch {
case activeConns > 0:
// always stop the timer when a new connection is made,
// even if timeout opt is 0
proxyAutoClose.Stop()
case opts.withInactivityTimeout <= 0:
// no timeout was set; the timer should not be reset for inactivity
case activeConns == 0:
proxyAutoClose.Reset(opts.withInactivityTimeout)
}
}
}
}()
p.connWg.Wait()
defer p.cancel()
{
// the go funcs are done, so we can safely close the chan and range over any errors
@ -367,6 +397,25 @@ func (p *ClientProxy) CloseSession(sessionTeardownTimeout time.Duration) error {
return nil
}
// CloseReason returns the reason why the proxy was closed, if the proxy closed
// itself. If the proxy is still running or the proxy was closed externally, an
// empty string is returned.
func (p *ClientProxy) CloseReason() string {
switch r := p.closeReason.Load().(type) {
case string:
return r
default:
return ""
}
}
// setCloseReason updates the reason the proxy closed from an empty string to the
// provided string. setCloseReason only accepts the first provided reason for
// closing; all other calls are ignored.
func (p *ClientProxy) setCloseReason(reason string) {
p.closeReason.CompareAndSwap(nil, reason)
}
// ListenerAddress returns the address of the client proxy listener. Because the
// listener is started with Start(), this could be called before listening
// occurs. To avoid returning until we have a valid value, pass a context;
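
The auto-close added to Start above boils down to a time.AfterFunc timer: it is armed with a 10-minute grace period before the first connection, stopped whenever at least one connection is active, and re-armed with the configured inactivity timeout once the active count drops back to zero. A minimal, self-contained sketch of that pattern follows (a hedged illustration; the names and the two-second timeout are not part of this change):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	inactivityTimeout := 2 * time.Second

	// Arm the timer; if it ever fires, the "session" is cancelled.
	autoClose := time.AfterFunc(inactivityTimeout, func() {
		fmt.Println("inactivity timeout reached")
		cancel()
	})

	activeConnCh := make(chan int32)
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case active := <-activeConnCh:
				if active > 0 {
					// At least one connection is live; never fire while active.
					autoClose.Stop()
				} else {
					// Back to zero connections; start the countdown again.
					autoClose.Reset(inactivityTimeout)
				}
			}
		}
	}()

	activeConnCh <- 1 // connection opened: timer stopped
	activeConnCh <- 0 // connection closed: timer re-armed
	<-ctx.Done()      // fires ~2s after the last connection closed
}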

@ -69,6 +69,7 @@ type Command struct {
flagUsername string
flagDbname string
flagMongoDbAuthenticationDatabase string
flagInactiveTimeout time.Duration
// HTTP
httpFlags
@ -220,6 +221,13 @@ func (c *Command) Flags() *base.FlagSets {
Usage: "Target scope name, if authorizing the session via scope parameters and target name. Mutually exclusive with -scope-id.",
})
f.DurationVar(&base.DurationVar{
Name: "inactive-timeout",
Target: &c.flagInactiveTimeout,
Completion: complete.PredictAnything,
Usage: "How long to wait between connections before closing the session. Increase this value if the proxy closes during long-running processes, or use -1 to disable the timeout.",
})
switch c.Func {
case "connect":
f.StringVar(&base.StringVar{
@ -508,11 +516,32 @@ func (c *Command) Run(args []string) (retCode int) {
clientProxyCloseCh := make(chan struct{})
connCountCloseCh := make(chan struct{})
if c.flagInactiveTimeout == 0 {
// no timeout was specified by the user, so use our defaults based on subcommand
switch c.Func {
case "connect":
// connect is used when no subcommand is specified; this case should
// have the most generous timeout
apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(30*time.Second))
case "rdp":
// rdp has a gui, so give the user a chance to click "reconnect"
apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(5*time.Second))
case "ssh":
// one second is probably enough for ssh
apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(time.Second))
default:
// for other protocols, give some extra leeway just in case
apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(3*time.Second))
}
} else {
apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(c.flagInactiveTimeout))
}
proxyError := new(atomic.Error)
go func() {
defer close(clientProxyCloseCh)
if err = clientProxy.Start(); err != nil {
c.proxyCancel()
defer c.proxyCancel()
if err = clientProxy.Start(apiProxyOpts...); err != nil {
proxyError.Store(err)
}
}()
@ -595,10 +624,8 @@ func (c *Command) Run(args []string) (retCode int) {
if c.execCmdReturnValue != nil {
// Don't print out in this case, so ensure we clear it
termInfo.Reason = ""
} else if time.Now().After(clientProxy.SessionExpiration()) {
termInfo.Reason = "Session has expired"
} else if clientProxy.ConnectionsLeft() == 0 {
termInfo.Reason = "No connections left in session"
} else if r := clientProxy.CloseReason(); r != "" {
termInfo.Reason = r
} else if err := proxyError.Load(); err != nil {
termInfo.Reason = "Error from proxy client: " + err.Error()
}
@ -825,10 +852,9 @@ func (c *Command) handleExec(clientProxy *apiproxy.ClientProxy, passthroughArgs
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmdExit := make(chan struct{})
if err := cmd.Run(); err != nil {
exitCode := 2
cmdError := func(err error) {
if exitError, ok := err.(*exec.ExitError); ok {
if exitError.Success() {
c.execCmdReturnValue.Store(0)
@ -841,8 +867,30 @@ func (c *Command) handleExec(clientProxy *apiproxy.ClientProxy, passthroughArgs
}
c.PrintCliError(fmt.Errorf("Failed to run command: %w", err))
c.execCmdReturnValue.Store(int32(exitCode))
c.execCmdReturnValue.Store(2)
return
}
c.execCmdReturnValue.Store(0)
go func() {
defer close(cmdExit)
if err := cmd.Start(); err != nil {
cmdError(err)
return
}
if err := cmd.Wait(); err != nil {
cmdError(err)
return
}
c.execCmdReturnValue.Store(0)
}()
for {
select {
case <-c.proxyCtx.Done():
// the proxy exited for some reason; end the cmd since connections are no longer possible
_ = endProcess(cmd.Process)
case <-cmdExit:
return
}
}
}
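
With the flag wired in above, a user can override the per-subcommand defaults. An illustrative invocation (the target ID is a placeholder):

boundary connect ssh -target-id ttcp_1234567890 -inactive-timeout 10m

Passing -inactive-timeout -1 disables the auto-close entirely, while leaving the flag unset keeps the per-subcommand defaults chosen in Run above.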

@ -0,0 +1,19 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !windows
package connect
import (
"os"
"syscall"
)
// endProcess gracefully ends the provided os process
func endProcess(p *os.Process) error {
if p == nil {
return nil
}
return p.Signal(syscall.SIGTERM)
}

@ -0,0 +1,18 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build windows
package connect
import (
"os"
)
// endProcess kills the provided os process
func endProcess(p *os.Process) error {
if p == nil {
return nil
}
return p.Kill()
}
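
Taken together, handleExec now runs the child process asynchronously and relies on endProcess to stop it if the proxy context ends first. A condensed sketch of that shape (variable names illustrative, exit-code bookkeeping elided):

cmdExit := make(chan struct{})
go func() {
	defer close(cmdExit)
	_ = cmd.Run() // Start + Wait; record the exit code here
}()
select {
case <-proxyCtx.Done():
	// The proxy is gone; ask the child to stop (SIGTERM on unix, Kill on Windows).
	_ = endProcess(cmd.Process)
	<-cmdExit
case <-cmdExit:
}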

@ -59,7 +59,7 @@ func (r *rdpFlags) buildArgs(c *Command, port, ip, addr string) []string {
case "mstsc.exe":
args = append(args, "/v", addr)
case "open":
args = append(args, "-n", "-W", fmt.Sprintf("rdp://full%saddress=s%s%s", "%20", "%3A", url.QueryEscape(addr)))
args = append(args, "-W", fmt.Sprintf("rdp://full%saddress=s%s%s", "%20", "%3A", url.QueryEscape(addr)))
}
return args
}

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/boundary/internal/tests/helper"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
_ "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets/tcp"
)
@ -140,3 +141,99 @@ func TestConnectionsLeft(t *testing.T) {
// Wait to ensure cleanup and that the second-start logic works
wg.Wait()
}
func TestConnectionTimeout(t *testing.T) {
require := require.New(t)
logger := hclog.New(&hclog.LoggerOptions{
Name: t.Name(),
Level: hclog.Trace,
})
// Create controller and worker
conf, err := config.DevController()
require.NoError(err)
c1 := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf,
InitialResourcesSuffix: "1234567890",
Logger: logger.Named("c1"),
WorkerRPCGracePeriod: helper.DefaultControllerRPCGracePeriod,
})
helper.ExpectWorkers(t, c1)
w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
WorkerAuthKms: c1.Config().WorkerAuthKms,
InitialUpstreams: c1.ClusterAddrs(),
Logger: logger.Named("w1"),
SuccessfulControllerRPCGracePeriodDuration: helper.DefaultControllerRPCGracePeriod,
Name: "w1",
})
helper.ExpectWorkers(t, c1, w1)
// Connect target
client := c1.Client()
client.SetToken(c1.Token().Token)
tcl := targets.NewClient(client)
tgt, err := tcl.Read(c1.Context(), "ttcp_1234567890")
require.NoError(err)
require.NotNil(tgt)
// Create test server, update default port on target
ts := helper.NewTestTcpServer(t)
require.NotNil(ts)
defer ts.Close()
var sessionConnsLimit int32 = 2
tgt = updateTargetForProxy(t, c1.Context(), tcl, tgt, ts.Port(), sessionConnsLimit, w1.Name())
// Authorize session to get authorization data
sess, err := tcl.AuthorizeSession(c1.Context(), tgt.Item.Id)
require.NoError(err)
sessAuthz, err := sess.GetSessionAuthorization()
require.NoError(err)
// Create a context we can cancel to stop the proxy, a channel for conns
// left, and a waitgroup to ensure cleanup
pxyCtx, pxyCancel := context.WithCancel(c1.Context())
defer pxyCancel()
wg := new(sync.WaitGroup)
pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken)
require.NoError(err)
wg.Add(1)
done := atomic.NewBool(false)
go func() {
defer wg.Done()
require.NoError(pxy.Start(proxy.WithInactivityTimeout(time.Second)))
done.Store(true)
}()
addr := pxy.ListenerAddress(context.Background())
require.NotEmpty(addr)
addrPort, err := netip.ParseAddrPort(addr)
require.NoError(err)
echo := []byte("echo")
readBuf := make([]byte, len(echo))
conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(addrPort))
require.NoError(err)
written, err := conn.Write(echo)
require.NoError(err)
require.Equal(written, len(echo))
read, err := conn.Read(readBuf)
require.NoError(err)
require.Equal(read, len(echo))
require.NoError(conn.Close())
start := time.Now()
for {
if done.Load() || time.Since(start) > time.Second*2 {
require.True(done.Load(), "proxy did not close itself within the expected time frame (2 seconds)")
break
}
time.Sleep(10 * time.Millisecond)
}
require.Equal("Inactivity timeout reached", pxy.CloseReason())
pxyCancel()
wg.Wait()
}
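
The final assertion relies on setCloseReason being first-writer-wins: closeReason is only ever set via CompareAndSwap(nil, reason), so the first recorded reason is the one CloseReason reports even if the session is cancelled again later. A tiny sketch of that behavior, assuming the standard library's atomic.Value (the values are borrowed from strings in this diff, otherwise illustrative):

var reason atomic.Value // sync/atomic Value, as closeReason is assumed to be
reason.CompareAndSwap(nil, "Inactivity timeout reached")     // first reason: stored
reason.CompareAndSwap(nil, "No connections left in session") // later reason: ignored
fmt.Println(reason.Load()) // prints "Inactivity timeout reached"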
