From a3df8b1b3db05f035c998df8be4f4575f9bd6daa Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Fri, 9 Feb 2024 17:05:52 -0500 Subject: [PATCH] - take out context cancellation from `listenerCloseFunc()` - Close websocket connection which throws error when the context is canceled --- api/proxy/proxy.go | 43 +++++------------------- api/proxy/websocket.go | 1 + internal/cmd/commands/connect/connect.go | 16 +++------ internal/cmd/commands/connect/funcs.go | 3 +- internal/tests/api/proxy/proxy_test.go | 8 ----- 5 files changed, 16 insertions(+), 55 deletions(-) diff --git a/api/proxy/proxy.go b/api/proxy/proxy.go index 8a551ae14d..5375fa3c74 100644 --- a/api/proxy/proxy.go +++ b/api/proxy/proxy.go @@ -35,7 +35,6 @@ type ClientProxy struct { tofuToken string cachedListenerAddress *ua.String connectionsLeft *atomic.Int32 - connectionsCount *atomic.Int32 connsLeftCh chan int32 callerConnectionsLeftCh chan int32 sessionAuthzData *targets.SessionAuthorizationData @@ -91,7 +90,6 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e cachedListenerAddress: ua.NewString(""), connsLeftCh: make(chan int32), connectionsLeft: new(atomic.Int32), - connectionsCount: new(atomic.Int32), listener: new(atomic.Value), listenerCloseOnce: new(sync.Once), connWg: new(sync.WaitGroup), @@ -128,7 +126,6 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e } } p.connectionsLeft.Store(p.sessionAuthzData.ConnectionLimit) - p.connectionsCount.Store(0) p.workerAddr = p.sessionAuthzData.WorkerInfo[0].Address tlsConf, err := p.clientTlsConfig(opt...) @@ -193,12 +190,9 @@ func (p *ClientProxy) Start() (retErr error) { listenerCloseFunc := func() { p.listenerCloseOnce.Do(func() { // Forces the for loop to exit instead of spinning on errors - p.cancel() p.connectionsLeft.Store(0) - if err := p.listener.Load().(net.Listener).Close(); err != nil { - if !errors.Is(err, net.ErrClosed) { - retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) - } + if err := p.listener.Load().(net.Listener).Close(); err != nil && err != net.ErrClosed { + retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) } }) } @@ -233,18 +227,13 @@ func (p *ClientProxy) Start() (retErr error) { // connection that comes our way, so cancel the proxy fin <- fmt.Errorf("error from accept: %w", err) listenerCloseFunc() + p.cancel() return } } - p.connWg.Add(1) - p.connectionsCount.Add(1) go func() { defer listeningConn.Close() - defer func() { - p.connectionsCount.Add(-1) - p.connsLeftCh <- p.connectionsLeft.Load() - }() defer p.connWg.Done() wsConn, err := p.getWsConn(p.ctx) if err != nil { @@ -252,6 +241,7 @@ func (p *ClientProxy) Start() (retErr error) { // No reason to think we can successfully handle the next // connection that comes our way, so cancel the proxy listenerCloseFunc() + p.cancel() return } if err := p.runTcpProxyV1(wsConn, listeningConn); err != nil { @@ -259,6 +249,7 @@ func (p *ClientProxy) Start() (retErr error) { // No reason to think we can successfully handle the next // connection that comes our way, so cancel the proxy listenerCloseFunc() + p.cancel() return } }() @@ -288,21 +279,12 @@ func (p *ClientProxy) Start() (retErr error) { return case connsLeft := <-p.connsLeftCh: p.connectionsLeft.Store(connsLeft) - if p.callerConnectionsLeftCh != nil && p.sessionAuthzData.ConnectionLimit != -1 { + if p.callerConnectionsLeftCh != nil { p.callerConnectionsLeftCh <- connsLeft } - - // If there are no connections left, close the listener - // to stop new connections from being accepted if connsLeft == 0 { - if err := p.listener.Load().(net.Listener).Close(); err != nil { - if !errors.Is(err, net.ErrClosed) { - retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err)) - } - } - } - - if connsLeft == 0 && p.ConnectionsCount() == 0 { + // Close the listener as we can't authorize any more + // connections return } } @@ -310,6 +292,7 @@ func (p *ClientProxy) Start() (retErr error) { }() p.connWg.Wait() + defer p.cancel() { // the go funcs are done, so we can safely close the chan and range over any errors @@ -408,11 +391,3 @@ func (p *ClientProxy) SessionExpiration() time.Time { func (p *ClientProxy) ConnectionsLeft() int32 { return p.connectionsLeft.Load() } - -// ConnectionsCount returns the number of connections in the session. -// -// EXPERIMENTAL: While this API is not expected to change, it is new and -// feedback from users may necessitate changes. -func (p *ClientProxy) ConnectionsCount() int32 { - return p.connectionsCount.Load() -} diff --git a/api/proxy/websocket.go b/api/proxy/websocket.go index 0774b51470..018e13e430 100644 --- a/api/proxy/websocket.go +++ b/api/proxy/websocket.go @@ -66,6 +66,7 @@ func (p *ClientProxy) sendSessionTeardown(ctx context.Context) error { if err := wspb.Write(ctx, wsConn, &handshake); err != nil { return fmt.Errorf("error sending teardown handshake to worker: %w", err) } + wsConn.Close(websocket.StatusNormalClosure, "session teardown finished") return nil } diff --git a/internal/cmd/commands/connect/connect.go b/internal/cmd/commands/connect/connect.go index c80f528923..e840506bab 100644 --- a/internal/cmd/commands/connect/connect.go +++ b/internal/cmd/commands/connect/connect.go @@ -45,8 +45,7 @@ type SessionInfo struct { } type ConnectionInfo struct { - ConnectionsLeft int32 `json:"connections_left"` - ConnectionsCount int32 `json:"connections_count"` + ConnectionsLeft int32 `json:"connections_left"` } type TerminationInfo struct { @@ -453,12 +452,8 @@ func (c *Command) Run(args []string) (retCode int) { // done it manually return case connsLeft := <-connsLeftCh: - connectionsCount := clientProxy.ConnectionsCount() - c.updateConnsLeft(connsLeft, connectionsCount) - - // If there are no available connections left and there are no connections - // we can exit - if connsLeft == 0 && connectionsCount == 0 { + c.updateConnsLeft(connsLeft) + if connsLeft == 0 { return } } @@ -576,10 +571,9 @@ func (c *Command) printCredentials(creds []*targets.SessionCredential) error { return nil } -func (c *Command) updateConnsLeft(connsLeft int32, connsCount int32) { +func (c *Command) updateConnsLeft(connsLeft int32) { connInfo := ConnectionInfo{ - ConnectionsLeft: connsLeft, - ConnectionsCount: connsCount, + ConnectionsLeft: connsLeft, } if c.flagExec == "" { diff --git a/internal/cmd/commands/connect/funcs.go b/internal/cmd/commands/connect/funcs.go index 91782c48bc..da115b8048 100644 --- a/internal/cmd/commands/connect/funcs.go +++ b/internal/cmd/commands/connect/funcs.go @@ -114,8 +114,7 @@ func generateConnectionInfoTableOutput(in ConnectionInfo) string { var ret []string nonAttributeMap := map[string]any{ - "Connections Left": in.ConnectionsLeft, - "Connections Count": in.ConnectionsCount, + "Connections Left": in.ConnectionsLeft, } maxLength := 0 diff --git a/internal/tests/api/proxy/proxy_test.go b/internal/tests/api/proxy/proxy_test.go index 1714a989aa..461bf248b8 100644 --- a/internal/tests/api/proxy/proxy_test.go +++ b/internal/tests/api/proxy/proxy_test.go @@ -109,7 +109,6 @@ func TestConnectionsLeft(t *testing.T) { // While we have sessions left, expect no error and to read and write // through the proxy, and to read the conns left from the channel. Once we // have hit the connection limit, we expect an error on dial. - connectionCount := int32(0) for i := sessionConnsLimit; i >= 0; i-- { // Give time for the listener to be closed. The information about conns // left comes from upstream responses and we can circle around too fast @@ -121,7 +120,6 @@ func TestConnectionsLeft(t *testing.T) { break } require.NoError(err) - written, err := conn.Write(echo) require.NoError(err) require.Equal(written, len(echo)) @@ -131,15 +129,9 @@ func TestConnectionsLeft(t *testing.T) { connsLeft := <-connsLeftCh require.Equal(i-1, connsLeft) require.Equal(i-1, pxy.ConnectionsLeft()) - connectionCount++ - require.Equal(connectionCount, pxy.ConnectionsCount()) - } pxyCancel() // Wait to ensure cleanup and that the second-start logic works wg.Wait() - - require.Equal(int32(0), pxy.ConnectionsLeft()) - require.Equal(int32(0), pxy.ConnectionsCount()) }