diff --git a/api/proxy/option.go b/api/proxy/option.go
index 3f37d89c33..936f053fb1 100644
--- a/api/proxy/option.go
+++ b/api/proxy/option.go
@@ -39,6 +39,7 @@ type Options struct {
 	WithSkipSessionTeardown    bool
 	withSessionTeardownTimeout time.Duration
 	withApiClient              *api.Client
+	withInactivityTimeout      time.Duration
 }
 
 // Option is a function that takes in an options struct and sets values or
@@ -142,3 +143,12 @@ func WithApiClient(with *api.Client) Option {
 		return nil
 	}
 }
+
+// WithInactivityTimeout provides an optional duration after which a session
+// with no active connections will be cancelled
+func WithInactivityTimeout(with time.Duration) Option {
+	return func(o *Options) error {
+		o.withInactivityTimeout = with
+		return nil
+	}
+}
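For readers unfamiliar with the pattern, `WithInactivityTimeout` follows the functional-options idiom used throughout this package: each `Option` mutates an `Options` struct, and the constructor applies them in order. A minimal self-contained sketch of that idiom (the `getOpts` helper here is illustrative, not the package's actual code):

```go
package main

import (
	"fmt"
	"time"
)

// Options and Option mirror the shapes used in api/proxy; getOpts is a
// hypothetical helper for this sketch, not the package's real function.
type Options struct {
	withInactivityTimeout time.Duration
}

type Option func(*Options) error

func WithInactivityTimeout(with time.Duration) Option {
	return func(o *Options) error {
		o.withInactivityTimeout = with
		return nil
	}
}

// getOpts applies each option in order, collecting the result.
func getOpts(opt ...Option) (*Options, error) {
	opts := &Options{}
	for _, o := range opt {
		if err := o(opts); err != nil {
			return nil, err
		}
	}
	return opts, nil
}

func main() {
	opts, err := getOpts(WithInactivityTimeout(30 * time.Second))
	if err != nil {
		panic(err)
	}
	fmt.Println(opts.withInactivityTimeout) // 30s
}
```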
diff --git a/api/proxy/proxy.go b/api/proxy/proxy.go
index 8aa89bac9a..99f16bb4f5 100644
--- a/api/proxy/proxy.go
+++ b/api/proxy/proxy.go
@@ -34,26 +34,28 @@ import (
 const sessionCancelTimeout = 30 * time.Second
 
 type ClientProxy struct {
-	tofuToken               string
-	cachedListenerAddress   *ua.String
-	connectionsLeft         *atomic.Int32
-	connsLeftCh             chan int32
-	callerConnectionsLeftCh chan int32
-	apiClient               *api.Client
-	sessionAuthzData        *targets.SessionAuthorizationData
-	createTime              time.Time
-	expiration              time.Time
-	ctx                     context.Context
-	cancel                  context.CancelFunc
-	transport               *http.Transport
-	workerAddr              string
-	listenAddrPort          netip.AddrPort
-	listener                *atomic.Value
-	listenerCloseOnce       *sync.Once
-	clientTlsConf           *tls.Config
-	connWg                  *sync.WaitGroup
-	started                 *atomic.Bool
-	skipSessionTeardown     bool
+	tofuToken             string
+	cachedListenerAddress *ua.String
+	connectionsLeft       *atomic.Int32
+	activeConns           *atomic.Int32
+	connsLeftCh           chan int32
+	callerConnsLeftCh     chan int32
+	apiClient             *api.Client
+	sessionAuthzData      *targets.SessionAuthorizationData
+	createTime            time.Time
+	expiration            time.Time
+	ctx                   context.Context
+	cancel                context.CancelFunc
+	transport             *http.Transport
+	workerAddr            string
+	listenAddrPort        netip.AddrPort
+	listener              *atomic.Value
+	listenerCloseOnce     *sync.Once
+	clientTlsConf         *tls.Config
+	connWg                *sync.WaitGroup
+	started               *atomic.Bool
+	skipSessionTeardown   bool
+	closeReason           *atomic.Value
 }
 
 // New creates a new client proxy. The given context should be cancelable; once
@@ -90,17 +92,19 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
 	}
 
 	p := &ClientProxy{
-		cachedListenerAddress:   ua.NewString(""),
-		connsLeftCh:             make(chan int32),
-		connectionsLeft:         new(atomic.Int32),
-		listener:                new(atomic.Value),
-		listenerCloseOnce:       new(sync.Once),
-		connWg:                  new(sync.WaitGroup),
-		listenAddrPort:          opts.WithListenAddrPort,
-		callerConnectionsLeftCh: opts.WithConnectionsLeftCh,
-		started:                 new(atomic.Bool),
-		skipSessionTeardown:     opts.WithSkipSessionTeardown,
-		apiClient:               opts.withApiClient,
+		cachedListenerAddress: ua.NewString(""),
+		connsLeftCh:           make(chan int32),
+		connectionsLeft:       new(atomic.Int32),
+		activeConns:           new(atomic.Int32),
+		listener:              new(atomic.Value),
+		listenerCloseOnce:     new(sync.Once),
+		connWg:                new(sync.WaitGroup),
+		listenAddrPort:        opts.WithListenAddrPort,
+		callerConnsLeftCh:     opts.WithConnectionsLeftCh,
+		started:               new(atomic.Bool),
+		skipSessionTeardown:   opts.WithSkipSessionTeardown,
+		apiClient:             opts.withApiClient,
+		closeReason:           new(atomic.Value),
 	}
 
 	if opts.WithListener != nil {
@@ -142,7 +146,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
 	// We don't _rely_ on client-side timeout verification but this prevents us
 	// seeming to be ready for a connection that will immediately fail when we
 	// try to actually make it
-	p.ctx, p.cancel = context.WithDeadline(ctx, p.expiration)
+	p.ctx, p.cancel = context.WithDeadlineCause(ctx, p.expiration, fmt.Errorf("Session has expired"))
 
 	transport := cleanhttp.DefaultTransport()
 	transport.DisableKeepAlives = false
@@ -212,6 +216,17 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
 	// Ensure closing the listener runs on any other return condition
 	defer listenerCloseFunc()
 
+	// automatically close the proxy when inactive
+	proxyAutoClose := time.AfterFunc(10*time.Minute, func() {
+		p.setCloseReason("Inactivity timeout reached")
+		p.cancel()
+	})
+
+	activeConnCh := make(chan int32)
+	activeConnFn := func(d int32) {
+		activeConnCh <- p.activeConns.Add(d)
+	}
+
 	fin := make(chan error, 10)
 	p.connWg.Add(1)
 	go func() {
@@ -243,8 +258,10 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
 				return
 			}
 		}
+		activeConnFn(1)
 		p.connWg.Add(1)
 		go func() {
+			defer activeConnFn(-1)
 			defer listeningConn.Close()
 			defer p.connWg.Done()
 			wsConn, err := p.getWsConn(p.ctx)
@@ -305,27 +322,40 @@ func (p *ClientProxy) Start(opt ...Option) (retErr error) {
 		}()
 		defer p.connWg.Done()
 		defer listenerCloseFunc()
-
 		for {
 			select {
 			case <-p.ctx.Done():
+				if err := context.Cause(p.ctx); !errors.Is(err, context.Canceled) {
+					p.setCloseReason(err.Error())
+				}
 				return
 			case connsLeft := <-p.connsLeftCh:
 				p.connectionsLeft.Store(connsLeft)
-				if p.callerConnectionsLeftCh != nil {
-					p.callerConnectionsLeftCh <- connsLeft
+				if p.callerConnsLeftCh != nil {
+					p.callerConnsLeftCh <- connsLeft
 				}
 				if connsLeft == 0 {
 					// Close the listener as we can't authorize any more
 					// connections
+					p.setCloseReason("No connections left in session")
 					return
 				}
+			case activeConns := <-activeConnCh:
+				switch {
+				case activeConns > 0:
+					// always stop the timer when a new connection is made,
+					// even if timeout opt is 0
+					proxyAutoClose.Stop()
+				case opts.withInactivityTimeout <= 0:
+					// no timeout was set, timer should not be reset for inactivity
+				case activeConns == 0:
+					proxyAutoClose.Reset(opts.withInactivityTimeout)
+				}
 			}
 		}
 	}()
 
 	p.connWg.Wait()
-	defer p.cancel()
 
 	{
 		// the go funcs are done, so we can safely close the chan and range over any errors
@@ -367,6 +397,25 @@ func (p *ClientProxy) CloseSession(sessionTeardownTimeout time.Duration) error {
 	return nil
 }
 
+// CloseReason returns the reason why the proxy was closed, if the proxy closed
+// itself. If the proxy is still running or the proxy was closed externally, an
+// empty string is returned.
+func (p *ClientProxy) CloseReason() string {
+	switch r := p.closeReason.Load().(type) {
+	case string:
+		return r
+	default:
+		return ""
+	}
+}
+
+// setCloseReason records the reason the proxy was closed. Only the first
+// provided reason is kept; all subsequent calls are ignored, so the earliest
+// cause wins.
+func (p *ClientProxy) setCloseReason(reason string) {
+	p.closeReason.CompareAndSwap(nil, reason)
+}
+
 // ListenerAddress returns the address of the client proxy listener. Because the
 // listener is started with Start(), this could be called before listening
 // occurs. To avoid returning until we have a valid value, pass a context;
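The new Start logic combines two techniques: a `time.AfterFunc` timer that is stopped while any connection is active and reset to the configured timeout once the active count returns to zero, and a first-write-wins close reason via `atomic.Value.CompareAndSwap(nil, ...)`. The real code serializes timer updates through `activeConnCh` into the select loop above; this standalone sketch (the `idleCloser` type is hypothetical, invented for illustration) shows the same pattern in a simpler form:

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// idleCloser signals done when no connections have been active for the given
// timeout. closeReason keeps only the first reason recorded, mirroring
// setCloseReason's CompareAndSwap(nil, reason).
type idleCloser struct {
	timer       *time.Timer
	timeout     time.Duration
	active      atomic.Int32
	closeReason atomic.Value
	done        chan struct{}
}

func newIdleCloser(timeout time.Duration) *idleCloser {
	c := &idleCloser{timeout: timeout, done: make(chan struct{})}
	c.timer = time.AfterFunc(timeout, func() {
		// record the reason before signaling, so readers always see it
		c.closeReason.CompareAndSwap(nil, "Inactivity timeout reached")
		close(c.done)
	})
	return c
}

func (c *idleCloser) connOpened() {
	if c.active.Add(1) > 0 {
		c.timer.Stop() // never time out while a connection is live
	}
}

func (c *idleCloser) connClosed() {
	if c.active.Add(-1) == 0 {
		c.timer.Reset(c.timeout) // restart the countdown once idle again
	}
}

func main() {
	c := newIdleCloser(500 * time.Millisecond)
	c.connOpened()
	time.Sleep(time.Second) // an active connection holds the timer open
	c.connClosed()
	<-c.done
	fmt.Println(c.closeReason.Load()) // Inactivity timeout reached
}
```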
diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go
index 1519a4954f..30f498ddde 100644
--- a/internal/cmd/commands/connect/connect.go
+++ b/internal/cmd/commands/connect/connect.go
@@ -69,6 +69,7 @@ type Command struct {
 	flagUsername                      string
 	flagDbname                        string
 	flagMongoDbAuthenticationDatabase string
+	flagInactiveTimeout               time.Duration
 
 	// HTTP
 	httpFlags
@@ -220,6 +221,13 @@
 		Usage: "Target scope name, if authorizing the session via scope parameters and target name. Mutually exclusive with -scope-id.",
 	})
 
+	f.DurationVar(&base.DurationVar{
+		Name:       "inactive-timeout",
+		Target:     &c.flagInactiveTimeout,
+		Completion: complete.PredictAnything,
+		Usage:      "How long to wait between connections before closing the session. Increase this value if the proxy closes during long-running processes, or use -1 to disable the timeout.",
+	})
+
 	switch c.Func {
 	case "connect":
 		f.StringVar(&base.StringVar{
@@ -508,11 +516,32 @@ func (c *Command) Run(args []string) (retCode int) {
 	clientProxyCloseCh := make(chan struct{})
 	connCountCloseCh := make(chan struct{})
 
+	if c.flagInactiveTimeout == 0 {
+		// no timeout was specified by the user, so use our defaults based on subcommand
+		switch c.Func {
+		case "connect":
+			// connect is used when no subcommand is specified; this case
+			// should have the most generous timeout
+			apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(30*time.Second))
+		case "rdp":
+			// rdp has a gui, so give the user a chance to click "reconnect"
+			apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(5*time.Second))
+		case "ssh":
+			// one second is probably enough for ssh
+			apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(time.Second))
+		default:
+			// for other protocols, give some extra leeway just in case
+			apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(3*time.Second))
+		}
+	} else {
+		apiProxyOpts = append(apiProxyOpts, apiproxy.WithInactivityTimeout(c.flagInactiveTimeout))
+	}
+
 	proxyError := new(atomic.Error)
 	go func() {
 		defer close(clientProxyCloseCh)
-		if err = clientProxy.Start(); err != nil {
-			c.proxyCancel()
+		defer c.proxyCancel()
+		if err = clientProxy.Start(apiProxyOpts...); err != nil {
 			proxyError.Store(err)
 		}
 	}()
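Putting the pieces together, a caller drives the proxy roughly as Run does above: construct it, start it with the chosen timeout, and consult `CloseReason` once it stops. A condensed sketch (the `runProxy` wrapper is hypothetical; the authorization token is assumed to come from `AuthorizeSession`, as elsewhere in this diff):

```go
package connectexample

import (
	"context"
	"fmt"
	"time"

	apiproxy "github.com/hashicorp/boundary/api/proxy"
)

// runProxy mirrors the Run() wiring above: start the proxy with an
// inactivity timeout, wait for it to stop, then surface the close reason.
func runProxy(ctx context.Context, authzToken string) error {
	clientProxy, err := apiproxy.New(ctx, authzToken)
	if err != nil {
		return err
	}

	proxyErr := make(chan error, 1)
	go func() {
		proxyErr <- clientProxy.Start(apiproxy.WithInactivityTimeout(30 * time.Second))
	}()

	if err := <-proxyErr; err != nil {
		return fmt.Errorf("Error from proxy client: %w", err)
	}
	// If the proxy shut itself down, CloseReason explains why, e.g.
	// "Inactivity timeout reached" or "No connections left in session".
	if r := clientProxy.CloseReason(); r != "" {
		fmt.Println("Session terminated:", r)
	}
	return nil
}
```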
@@ -595,10 +624,8 @@ func (c *Command) Run(args []string) (retCode int) {
 	if c.execCmdReturnValue != nil {
 		// Don't print out in this case, so ensure we clear it
 		termInfo.Reason = ""
-	} else if time.Now().After(clientProxy.SessionExpiration()) {
-		termInfo.Reason = "Session has expired"
-	} else if clientProxy.ConnectionsLeft() == 0 {
-		termInfo.Reason = "No connections left in session"
+	} else if r := clientProxy.CloseReason(); r != "" {
+		termInfo.Reason = r
 	} else if err := proxyError.Load(); err != nil {
 		termInfo.Reason = "Error from proxy client: " + err.Error()
 	}
@@ -825,10 +852,9 @@ func (c *Command) handleExec(clientProxy *apiproxy.ClientProxy, passthroughArgs
 	cmd.Stdin = os.Stdin
 	cmd.Stdout = os.Stdout
 	cmd.Stderr = os.Stderr
+	cmdExit := make(chan struct{})
 
-	if err := cmd.Run(); err != nil {
-		exitCode := 2
-
+	cmdError := func(err error) {
 		if exitError, ok := err.(*exec.ExitError); ok {
 			if exitError.Success() {
 				c.execCmdReturnValue.Store(0)
@@ -841,8 +867,30 @@ func (c *Command) handleExec(clientProxy *apiproxy.ClientProxy, passthroughArgs
 		}
 
 		c.PrintCliError(fmt.Errorf("Failed to run command: %w", err))
-		c.execCmdReturnValue.Store(int32(exitCode))
+		c.execCmdReturnValue.Store(2)
 		return
 	}
-	c.execCmdReturnValue.Store(0)
+
+	if err := cmd.Start(); err != nil {
+		cmdError(err)
+		return
+	}
+
+	go func() {
+		defer close(cmdExit)
+		if err := cmd.Wait(); err != nil {
+			cmdError(err)
+			return
+		}
+		c.execCmdReturnValue.Store(0)
+	}()
+
+	select {
+	case <-c.proxyCtx.Done():
+		// the proxy exited for some reason; end the cmd since connections
+		// are no longer possible, then wait for the command to finish
+		_ = endProcess(cmd.Process)
+		<-cmdExit
+	case <-cmdExit:
+	}
 }
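The handleExec change swaps `cmd.Run()` for an explicit `cmd.Start()`/`cmd.Wait()` split so the command can be terminated when the proxy context ends; starting the command before spawning the wait goroutine also keeps `cmd.Process` safe to read from the select. A self-contained, POSIX-only sketch of that shutdown flow (the `sleep 30` child and one-second timeout are stand-ins; `endProcess` matches the new non-Windows file below):

```go
package main

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"syscall"
	"time"
)

// endProcess matches the non-Windows variant added in this diff: ask the
// process to terminate gracefully.
func endProcess(p *os.Process) error {
	if p == nil {
		return nil
	}
	return p.Signal(syscall.SIGTERM)
}

func main() {
	// Simulate the proxy context being canceled mid-command.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	cmd := exec.Command("sleep", "30")
	if err := cmd.Start(); err != nil {
		panic(err)
	}

	cmdExit := make(chan struct{})
	go func() {
		defer close(cmdExit)
		_ = cmd.Wait() // a returned *exec.ExitError would drive the exit code, as in cmdError
	}()

	select {
	case <-ctx.Done():
		// The proxy is gone, so connections are impossible; end the command
		// and wait for it to exit.
		_ = endProcess(cmd.Process)
		<-cmdExit
		fmt.Println("command terminated after proxy shutdown")
	case <-cmdExit:
		fmt.Println("command exited on its own")
	}
}
```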
diff --git a/internal/cmd/commands/connect/end_process_nonwindows.go b/internal/cmd/commands/connect/end_process_nonwindows.go
new file mode 100644
index 0000000000..bc66eb0523
--- /dev/null
+++ b/internal/cmd/commands/connect/end_process_nonwindows.go
@@ -0,0 +1,19 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+//go:build !windows
+
+package connect
+
+import (
+	"os"
+	"syscall"
+)
+
+// endProcess gracefully ends the provided os process
+func endProcess(p *os.Process) error {
+	if p == nil {
+		return nil
+	}
+	return p.Signal(syscall.SIGTERM)
+}
diff --git a/internal/cmd/commands/connect/end_process_windows.go b/internal/cmd/commands/connect/end_process_windows.go
new file mode 100644
index 0000000000..e5389d3a83
--- /dev/null
+++ b/internal/cmd/commands/connect/end_process_windows.go
@@ -0,0 +1,18 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+//go:build windows
+
+package connect
+
+import (
+	"os"
+)
+
+// endProcess kills the provided os process
+func endProcess(p *os.Process) error {
+	if p == nil {
+		return nil
+	}
+	return p.Kill()
+}
diff --git a/internal/cmd/commands/connect/rdp.go b/internal/cmd/commands/connect/rdp.go
index 24ce167801..ffabfe4c24 100644
--- a/internal/cmd/commands/connect/rdp.go
+++ b/internal/cmd/commands/connect/rdp.go
@@ -59,7 +59,7 @@ func (r *rdpFlags) buildArgs(c *Command, port, ip, addr string) []string {
 	case "mstsc.exe":
 		args = append(args, "/v", addr)
 	case "open":
-		args = append(args, "-n", "-W", fmt.Sprintf("rdp://full%saddress=s%s%s", "%20", "%3A", url.QueryEscape(addr)))
+		args = append(args, "-W", fmt.Sprintf("rdp://full%saddress=s%s%s", "%20", "%3A", url.QueryEscape(addr)))
 	}
 	return args
 }
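For reference, the `open` branch above pre-encodes the space and colon in the RDP `full address=s:<addr>` setting, while `url.QueryEscape` handles the address itself. A quick sketch showing the URL this produces for a hypothetical listener address:

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Hypothetical proxy listener address, as buildArgs would receive it.
	addr := "127.0.0.1:50123"

	// Mirrors the Sprintf in rdp.go: the literal space and colon are passed
	// in pre-encoded ("%20", "%3A"), and QueryEscape encodes the address.
	rdpURL := fmt.Sprintf("rdp://full%saddress=s%s%s", "%20", "%3A", url.QueryEscape(addr))
	fmt.Println(rdpURL)
	// Output: rdp://full%20address=s%3A127.0.0.1%3A50123
}
```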
diff --git a/internal/tests/api/proxy/proxy_test.go b/internal/tests/api/proxy/proxy_test.go
index aa78af56ef..877ccc916e 100644
--- a/internal/tests/api/proxy/proxy_test.go
+++ b/internal/tests/api/proxy/proxy_test.go
@@ -19,6 +19,7 @@ import (
 	"github.com/hashicorp/boundary/internal/tests/helper"
 	"github.com/hashicorp/go-hclog"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/atomic"
 
 	_ "github.com/hashicorp/boundary/internal/daemon/controller/handlers/targets/tcp"
 )
@@ -140,3 +141,99 @@ func TestConnectionsLeft(t *testing.T) {
 	// Wait to ensure cleanup and that the second-start logic works
 	wg.Wait()
 }
+
+func TestConnectionTimeout(t *testing.T) {
+	require := require.New(t)
+	logger := hclog.New(&hclog.LoggerOptions{
+		Name:  t.Name(),
+		Level: hclog.Trace,
+	})
+
+	// Create controller and worker
+	conf, err := config.DevController()
+	require.NoError(err)
+	c1 := controller.NewTestController(t, &controller.TestControllerOpts{
+		Config:                 conf,
+		InitialResourcesSuffix: "1234567890",
+		Logger:                 logger.Named("c1"),
+		WorkerRPCGracePeriod:   helper.DefaultControllerRPCGracePeriod,
+	})
+	helper.ExpectWorkers(t, c1)
+
+	w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
+		WorkerAuthKms:    c1.Config().WorkerAuthKms,
+		InitialUpstreams: c1.ClusterAddrs(),
+		Logger:           logger.Named("w1"),
+		SuccessfulControllerRPCGracePeriodDuration: helper.DefaultControllerRPCGracePeriod,
+		Name: "w1",
+	})
+	helper.ExpectWorkers(t, c1, w1)
+
+	// Connect target
+	client := c1.Client()
+	client.SetToken(c1.Token().Token)
+	tcl := targets.NewClient(client)
+	tgt, err := tcl.Read(c1.Context(), "ttcp_1234567890")
+	require.NoError(err)
+	require.NotNil(tgt)
+
+	// Create test server, update default port on target
+	ts := helper.NewTestTcpServer(t)
+	require.NotNil(ts)
+	defer ts.Close()
+	var sessionConnsLimit int32 = 2
+
+	tgt = updateTargetForProxy(t, c1.Context(), tcl, tgt, ts.Port(), sessionConnsLimit, w1.Name())
+
+	// Authorize session to get authorization data
+	sess, err := tcl.AuthorizeSession(c1.Context(), tgt.Item.Id)
+	require.NoError(err)
+	sessAuthz, err := sess.GetSessionAuthorization()
+	require.NoError(err)
+
+	// Create a context we can cancel to stop the proxy, a channel for conns
+	// left, and a waitgroup to ensure cleanup
+	pxyCtx, pxyCancel := context.WithCancel(c1.Context())
+	defer pxyCancel()
+	wg := new(sync.WaitGroup)
+
+	pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken)
+	require.NoError(err)
+	wg.Add(1)
+	done := atomic.NewBool(false)
+	go func() {
+		defer wg.Done()
+		require.NoError(pxy.Start(proxy.WithInactivityTimeout(time.Second)))
+		done.Store(true)
+	}()
+
+	addr := pxy.ListenerAddress(context.Background())
+	require.NotEmpty(addr)
+	addrPort, err := netip.ParseAddrPort(addr)
+	require.NoError(err)
+
+	echo := []byte("echo")
+	readBuf := make([]byte, len(echo))
+
+	conn, err := net.DialTCP("tcp", nil, net.TCPAddrFromAddrPort(addrPort))
+	require.NoError(err)
+	written, err := conn.Write(echo)
+	require.NoError(err)
+	require.Equal(len(echo), written)
+	read, err := conn.Read(readBuf)
+	require.NoError(err)
+	require.Equal(len(echo), read)
+	require.NoError(conn.Close())
+
+	start := time.Now()
+	for {
+		if done.Load() || time.Since(start) > time.Second*2 {
+			require.True(done.Load(), "proxy did not close itself within the expected time frame (2 seconds)")
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	require.Equal("Inactivity timeout reached", pxy.CloseReason())
+	pxyCancel()
+	wg.Wait()
+}
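As an aside, the manual poll loop at the end of the test can also be expressed with testify's `Eventually` helper, which polls a condition at a tick interval until a deadline; a sketch using the same variables:

```go
// Equivalent to the poll loop above, using testify's Eventually helper
// (available on the require.Assertions created with require.New(t)).
require.Eventually(done.Load, 2*time.Second, 10*time.Millisecond,
	"proxy did not close itself within the expected time frame (2 seconds)")
require.Equal("Inactivity timeout reached", pxy.CloseReason())
```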