From 17f39ff35722c4cf2495fe8779dfa5ea5ab7f773 Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Fri, 9 Feb 2024 16:49:25 +0000 Subject: [PATCH] backport of commit 8dc8263e219d330266e10ffda57f367527a27a3c --- api/proxy/option.go | 13 +++ api/proxy/option_test.go | 12 +++ api/proxy/proxy.go | 101 +++++++++++++++-------- internal/cmd/commands/connect/connect.go | 19 +++-- internal/cmd/commands/connect/funcs.go | 3 +- internal/tests/api/proxy/proxy_test.go | 14 +++- 6 files changed, 121 insertions(+), 41 deletions(-) diff --git a/api/proxy/option.go b/api/proxy/option.go index c22f1961e5..75f6388c3a 100644 --- a/api/proxy/option.go +++ b/api/proxy/option.go @@ -32,6 +32,7 @@ type Options struct { WithListener net.Listener WithListenAddrPort netip.AddrPort WithConnectionsLeftCh chan int32 + WithConnectionsCountCh chan int32 WithWorkerHost string WithSessionAuthorizationData *targets.SessionAuthorizationData WithSkipSessionTeardown bool @@ -118,3 +119,15 @@ func WithSkipSessionTeardown(with bool) Option { return nil } } + +// WithConnectionsCountCh allows providing a channel to receive updates about the count of connections. +// It is the caller's responsibility to ensure that this is drained and does not block. +func WithConnectionsCountCh(with chan int32) Option { + return func(o *Options) error { + if with == nil { + return errors.New("channel passed to WithConnectionsCountCh is nil") + } + o.WithConnectionsCountCh = with + return nil + } +} diff --git a/api/proxy/option_test.go b/api/proxy/option_test.go index f982c2887d..3d721d2c4d 100644 --- a/api/proxy/option_test.go +++ b/api/proxy/option_test.go @@ -86,4 +86,16 @@ func Test_GetOpts(t *testing.T) { require.NoError(t, err) assert.True(opts.WithSkipSessionTeardown) }) + t.Run("with-connections-count-ch", func(t *testing.T) { + assert := assert.New(t) + opts, err := getOpts() + require.NoError(t, err) + assert.Nil(opts.WithConnectionsCountCh) + _, err = getOpts(WithConnectionsCountCh(nil)) + require.Error(t, err) + l := make(chan int32) + opts, err = getOpts(WithConnectionsCountCh(l)) + require.NoError(t, err) + assert.Equal(l, opts.WithConnectionsCountCh) + }) } diff --git a/api/proxy/proxy.go b/api/proxy/proxy.go index a565d12b85..7970a5cd2c 100644 --- a/api/proxy/proxy.go +++ b/api/proxy/proxy.go @@ -32,25 +32,28 @@ import ( const sessionCancelTimeout = 30 * time.Second type ClientProxy struct { - tofuToken string - cachedListenerAddress *ua.String - connectionsLeft *atomic.Int32 - connsLeftCh chan int32 - callerConnectionsLeftCh chan int32 - sessionAuthzData *targets.SessionAuthorizationData - createTime time.Time - expiration time.Time - ctx context.Context - cancel context.CancelFunc - transport *http.Transport - workerAddr string - listenAddrPort netip.AddrPort - listener *atomic.Value - listenerCloseOnce *sync.Once - clientTlsConf *tls.Config - connWg *sync.WaitGroup - started *atomic.Bool - skipSessionTeardown bool + tofuToken string + cachedListenerAddress *ua.String + connectionsLeft *atomic.Int32 + connectionsCount *atomic.Int32 + connsLeftCh chan int32 + connsCountCh chan int32 + callerConnectionsLeftCh chan int32 + callerConnectionsCountCh chan int32 + sessionAuthzData *targets.SessionAuthorizationData + createTime time.Time + expiration time.Time + ctx context.Context + cancel context.CancelFunc + transport *http.Transport + workerAddr string + listenAddrPort netip.AddrPort + listener *atomic.Value + listenerCloseOnce *sync.Once + clientTlsConf *tls.Config + connWg *sync.WaitGroup + started *atomic.Bool + skipSessionTeardown bool } // New creates a new client proxy. The given context should be cancelable; once @@ -87,16 +90,19 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e } p := &ClientProxy{ - cachedListenerAddress: ua.NewString(""), - connsLeftCh: make(chan int32), - connectionsLeft: new(atomic.Int32), - listener: new(atomic.Value), - listenerCloseOnce: new(sync.Once), - connWg: new(sync.WaitGroup), - listenAddrPort: opts.WithListenAddrPort, - callerConnectionsLeftCh: opts.WithConnectionsLeftCh, - started: new(atomic.Bool), - skipSessionTeardown: opts.WithSkipSessionTeardown, + cachedListenerAddress: ua.NewString(""), + connsLeftCh: make(chan int32), + connsCountCh: make(chan int32), + connectionsLeft: new(atomic.Int32), + connectionsCount: new(atomic.Int32), + listener: new(atomic.Value), + listenerCloseOnce: new(sync.Once), + connWg: new(sync.WaitGroup), + listenAddrPort: opts.WithListenAddrPort, + callerConnectionsLeftCh: opts.WithConnectionsLeftCh, + callerConnectionsCountCh: opts.WithConnectionsCountCh, + started: new(atomic.Bool), + skipSessionTeardown: opts.WithSkipSessionTeardown, } if opts.WithListener != nil { @@ -126,6 +132,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e } } p.connectionsLeft.Store(p.sessionAuthzData.ConnectionLimit) + p.connectionsCount.Store(0) p.workerAddr = p.sessionAuthzData.WorkerInfo[0].Address tlsConf, err := p.clientTlsConfig(opt...) @@ -192,8 +199,10 @@ func (p *ClientProxy) Start() (retErr error) { // Forces the for loop to exit instead of spinning on errors p.cancel() p.connectionsLeft.Store(0) - if err := p.listener.Load().(net.Listener).Close(); err != nil && err != net.ErrClosed { - retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) + if err := p.listener.Load().(net.Listener).Close(); err != nil { + if !errors.Is(err, net.ErrClosed) { + retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) + } } }) } @@ -231,9 +240,14 @@ func (p *ClientProxy) Start() (retErr error) { return } } + p.connWg.Add(1) + p.connsCountCh <- p.connectionsCount.Add(1) go func() { defer listeningConn.Close() + defer func() { + p.connsCountCh <- p.connectionsCount.Add(-1) + }() defer p.connWg.Done() wsConn, err := p.getWsConn(p.ctx) if err != nil { @@ -280,9 +294,22 @@ func (p *ClientProxy) Start() (retErr error) { if p.callerConnectionsLeftCh != nil { p.callerConnectionsLeftCh <- connsLeft } + // If there are no connections left, close the listener + // to stop new connections from being accepted if connsLeft == 0 { - // Close the listener as we can't authorize any more - // connections + if err := p.listener.Load().(net.Listener).Close(); err != nil { + if !errors.Is(err, net.ErrClosed) { + retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) + } + } + } + case connsCount := <-p.connsCountCh: + if p.callerConnectionsCountCh != nil { + p.callerConnectionsCountCh <- connsCount + } + // If there are no connections and no connections left, + // we can exit + if p.ConnectionsLeft() == 0 && connsCount == 0 { return } } @@ -388,3 +415,11 @@ func (p *ClientProxy) SessionExpiration() time.Time { func (p *ClientProxy) ConnectionsLeft() int32 { return p.connectionsLeft.Load() } + +// ConnectionsCount returns the number of connections in the session. +// +// EXPERIMENTAL: While this API is not expected to change, it is new and +// feedback from users may necessitate changes. +func (p *ClientProxy) ConnectionsCount() int32 { + return p.connectionsCount.Load() +} diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index e840506bab..ecf10dba14 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -45,7 +45,8 @@ type SessionInfo struct { } type ConnectionInfo struct { - ConnectionsLeft int32 `json:"connections_left"` + ConnectionsLeft int32 `json:"connections_left"` + ConnectionsCount int32 `json:"connections_count"` } type TerminationInfo struct { @@ -420,7 +421,8 @@ func (c *Command) Run(args []string) (retCode int) { listenAddr = netip.AddrPortFrom(addr, uint16(c.flagListenPort)) connsLeftCh := make(chan int32) - apiProxyOpts := []apiproxy.Option{apiproxy.WithConnectionsLeftCh(connsLeftCh)} + connsCountCh := make(chan int32) + apiProxyOpts := []apiproxy.Option{apiproxy.WithConnectionsLeftCh(connsLeftCh), apiproxy.WithConnectionsCountCh(connsCountCh)} if listenAddr.IsValid() { apiProxyOpts = append(apiProxyOpts, apiproxy.WithListenAddrPort(listenAddr)) } @@ -452,8 +454,12 @@ func (c *Command) Run(args []string) (retCode int) { // done it manually return case connsLeft := <-connsLeftCh: - c.updateConnsLeft(connsLeft) - if connsLeft == 0 { + c.updateConnsLeft(connsLeft, clientProxy.ConnectionsCount()) + case connsCount := <-connsCountCh: + c.updateConnsLeft(clientProxy.ConnectionsLeft(), connsCount) + // If there are no counts left and there are no connections + // we can exit + if clientProxy.ConnectionsLeft() == 0 && connsCount == 0 { return } } @@ -571,9 +577,10 @@ func (c *Command) printCredentials(creds []*targets.SessionCredential) error { return nil } -func (c *Command) updateConnsLeft(connsLeft int32) { +func (c *Command) updateConnsLeft(connsLeft int32, connsCount int32) { connInfo := ConnectionInfo{ - ConnectionsLeft: connsLeft, + ConnectionsLeft: connsLeft, + ConnectionsCount: connsCount, } if c.flagExec == "" { diff --git a/internal/cmd/commands/connect/funcs.go b/internal/cmd/commands/connect/funcs.go index da115b8048..91782c48bc 100644 --- a/internal/cmd/commands/connect/funcs.go +++ b/internal/cmd/commands/connect/funcs.go @@ -114,7 +114,8 @@ func generateConnectionInfoTableOutput(in ConnectionInfo) string { var ret []string nonAttributeMap := map[string]any{ - "Connections Left": in.ConnectionsLeft, + "Connections Left": in.ConnectionsLeft, + "Connections Count": in.ConnectionsCount, } maxLength := 0 diff --git a/internal/tests/api/proxy/proxy_test.go b/internal/tests/api/proxy/proxy_test.go index 461bf248b8..7c89c2c2df 100644 --- a/internal/tests/api/proxy/proxy_test.go +++ b/internal/tests/api/proxy/proxy_test.go @@ -86,9 +86,10 @@ func TestConnectionsLeft(t *testing.T) { pxyCtx, pxyCancel := context.WithCancel(c1.Context()) defer pxyCancel() connsLeftCh := make(chan int32) + connsCountCh := make(chan int32) wg := new(sync.WaitGroup) - pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken, proxy.WithConnectionsLeftCh(connsLeftCh)) + pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken, proxy.WithConnectionsLeftCh(connsLeftCh), proxy.WithConnectionsCountCh(connsCountCh)) require.NoError(err) wg.Add(1) go func() { @@ -109,6 +110,7 @@ func TestConnectionsLeft(t *testing.T) { // While we have sessions left, expect no error and to read and write // through the proxy, and to read the conns left from the channel. Once we // have hit the connection limit, we expect an error on dial. + connectionCount := int32(0) for i := sessionConnsLimit; i >= 0; i-- { // Give time for the listener to be closed. The information about conns // left comes from upstream responses and we can circle around too fast @@ -120,6 +122,12 @@ func TestConnectionsLeft(t *testing.T) { break } require.NoError(err) + + connsCount := <-connsCountCh + connectionCount++ + require.Equal(connectionCount, connsCount) + require.Equal(connectionCount, pxy.ConnectionsCount()) + written, err := conn.Write(echo) require.NoError(err) require.Equal(written, len(echo)) @@ -129,9 +137,13 @@ func TestConnectionsLeft(t *testing.T) { connsLeft := <-connsLeftCh require.Equal(i-1, connsLeft) require.Equal(i-1, pxy.ConnectionsLeft()) + } pxyCancel() // Wait to ensure cleanup and that the second-start logic works wg.Wait() + + require.Equal(int32(0), pxy.ConnectionsLeft()) + require.Equal(int32(0), pxy.ConnectionsCount()) }