- take out context cancellation from `listenerCloseFunc()`

- Close websocket connection which throws error when the context is canceled
pull/4389/head
Elim Tsiagbey 2 years ago
parent 91c86d3dc2
commit a3df8b1b3d

@ -35,7 +35,6 @@ type ClientProxy struct {
tofuToken string
cachedListenerAddress *ua.String
connectionsLeft *atomic.Int32
connectionsCount *atomic.Int32
connsLeftCh chan int32
callerConnectionsLeftCh chan int32
sessionAuthzData *targets.SessionAuthorizationData
@ -91,7 +90,6 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
cachedListenerAddress: ua.NewString(""),
connsLeftCh: make(chan int32),
connectionsLeft: new(atomic.Int32),
connectionsCount: new(atomic.Int32),
listener: new(atomic.Value),
listenerCloseOnce: new(sync.Once),
connWg: new(sync.WaitGroup),
@ -128,7 +126,6 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
}
}
p.connectionsLeft.Store(p.sessionAuthzData.ConnectionLimit)
p.connectionsCount.Store(0)
p.workerAddr = p.sessionAuthzData.WorkerInfo[0].Address
tlsConf, err := p.clientTlsConfig(opt...)
@ -193,12 +190,9 @@ func (p *ClientProxy) Start() (retErr error) {
listenerCloseFunc := func() {
p.listenerCloseOnce.Do(func() {
// Forces the for loop to exit instead of spinning on errors
p.cancel()
p.connectionsLeft.Store(0)
if err := p.listener.Load().(net.Listener).Close(); err != nil {
if !errors.Is(err, net.ErrClosed) {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
}
if err := p.listener.Load().(net.Listener).Close(); err != nil && err != net.ErrClosed {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
}
})
}
@ -233,18 +227,13 @@ func (p *ClientProxy) Start() (retErr error) {
// connection that comes our way, so cancel the proxy
fin <- fmt.Errorf("error from accept: %w", err)
listenerCloseFunc()
p.cancel()
return
}
}
p.connWg.Add(1)
p.connectionsCount.Add(1)
go func() {
defer listeningConn.Close()
defer func() {
p.connectionsCount.Add(-1)
p.connsLeftCh <- p.connectionsLeft.Load()
}()
defer p.connWg.Done()
wsConn, err := p.getWsConn(p.ctx)
if err != nil {
@ -252,6 +241,7 @@ func (p *ClientProxy) Start() (retErr error) {
// No reason to think we can successfully handle the next
// connection that comes our way, so cancel the proxy
listenerCloseFunc()
p.cancel()
return
}
if err := p.runTcpProxyV1(wsConn, listeningConn); err != nil {
@ -259,6 +249,7 @@ func (p *ClientProxy) Start() (retErr error) {
// No reason to think we can successfully handle the next
// connection that comes our way, so cancel the proxy
listenerCloseFunc()
p.cancel()
return
}
}()
@ -288,21 +279,12 @@ func (p *ClientProxy) Start() (retErr error) {
return
case connsLeft := <-p.connsLeftCh:
p.connectionsLeft.Store(connsLeft)
if p.callerConnectionsLeftCh != nil && p.sessionAuthzData.ConnectionLimit != -1 {
if p.callerConnectionsLeftCh != nil {
p.callerConnectionsLeftCh <- connsLeft
}
// If there are no connections left, close the listener
// to stop new connections from being accepted
if connsLeft == 0 {
if err := p.listener.Load().(net.Listener).Close(); err != nil {
if !errors.Is(err, net.ErrClosed) {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
}
}
}
if connsLeft == 0 && p.ConnectionsCount() == 0 {
// Close the listener as we can't authorize any more
// connections
return
}
}
@ -310,6 +292,7 @@ func (p *ClientProxy) Start() (retErr error) {
}()
p.connWg.Wait()
defer p.cancel()
{
// the go funcs are done, so we can safely close the chan and range over any errors
@ -408,11 +391,3 @@ func (p *ClientProxy) SessionExpiration() time.Time {
func (p *ClientProxy) ConnectionsLeft() int32 {
return p.connectionsLeft.Load()
}
// ConnectionsCount returns the number of connections in the session.
//
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func (p *ClientProxy) ConnectionsCount() int32 {
return p.connectionsCount.Load()
}

@ -66,6 +66,7 @@ func (p *ClientProxy) sendSessionTeardown(ctx context.Context) error {
if err := wspb.Write(ctx, wsConn, &handshake); err != nil {
return fmt.Errorf("error sending teardown handshake to worker: %w", err)
}
wsConn.Close(websocket.StatusNormalClosure, "session teardown finished")
return nil
}

@ -45,8 +45,7 @@ type SessionInfo struct {
}
type ConnectionInfo struct {
ConnectionsLeft int32 `json:"connections_left"`
ConnectionsCount int32 `json:"connections_count"`
ConnectionsLeft int32 `json:"connections_left"`
}
type TerminationInfo struct {
@ -453,12 +452,8 @@ func (c *Command) Run(args []string) (retCode int) {
// done it manually
return
case connsLeft := <-connsLeftCh:
connectionsCount := clientProxy.ConnectionsCount()
c.updateConnsLeft(connsLeft, connectionsCount)
// If there are no available connections left and there are no connections
// we can exit
if connsLeft == 0 && connectionsCount == 0 {
c.updateConnsLeft(connsLeft)
if connsLeft == 0 {
return
}
}
@ -576,10 +571,9 @@ func (c *Command) printCredentials(creds []*targets.SessionCredential) error {
return nil
}
func (c *Command) updateConnsLeft(connsLeft int32, connsCount int32) {
func (c *Command) updateConnsLeft(connsLeft int32) {
connInfo := ConnectionInfo{
ConnectionsLeft: connsLeft,
ConnectionsCount: connsCount,
ConnectionsLeft: connsLeft,
}
if c.flagExec == "" {

@ -114,8 +114,7 @@ func generateConnectionInfoTableOutput(in ConnectionInfo) string {
var ret []string
nonAttributeMap := map[string]any{
"Connections Left": in.ConnectionsLeft,
"Connections Count": in.ConnectionsCount,
"Connections Left": in.ConnectionsLeft,
}
maxLength := 0

@ -109,7 +109,6 @@ func TestConnectionsLeft(t *testing.T) {
// While we have sessions left, expect no error and to read and write
// through the proxy, and to read the conns left from the channel. Once we
// have hit the connection limit, we expect an error on dial.
connectionCount := int32(0)
for i := sessionConnsLimit; i >= 0; i-- {
// Give time for the listener to be closed. The information about conns
// left comes from upstream responses and we can circle around too fast
@ -121,7 +120,6 @@ func TestConnectionsLeft(t *testing.T) {
break
}
require.NoError(err)
written, err := conn.Write(echo)
require.NoError(err)
require.Equal(written, len(echo))
@ -131,15 +129,9 @@ func TestConnectionsLeft(t *testing.T) {
connsLeft := <-connsLeftCh
require.Equal(i-1, connsLeft)
require.Equal(i-1, pxy.ConnectionsLeft())
connectionCount++
require.Equal(connectionCount, pxy.ConnectionsCount())
}
pxyCancel()
// Wait to ensure cleanup and that the second-start logic works
wg.Wait()
require.Equal(int32(0), pxy.ConnectionsLeft())
require.Equal(int32(0), pxy.ConnectionsCount())
}

Loading…
Cancel
Save