backport of commit 8dc8263e21

pull/4398/head
Elim Tsiagbey 2 years ago
parent 806214d1f1
commit 17f39ff357

@ -32,6 +32,7 @@ type Options struct {
WithListener net.Listener
WithListenAddrPort netip.AddrPort
WithConnectionsLeftCh chan int32
WithConnectionsCountCh chan int32
WithWorkerHost string
WithSessionAuthorizationData *targets.SessionAuthorizationData
WithSkipSessionTeardown bool
@ -118,3 +119,15 @@ func WithSkipSessionTeardown(with bool) Option {
return nil
}
}
// WithConnectionsCountCh allows providing a channel to receive updates about the count of connections.
// It is the caller's responsibility to ensure that this is drained and does not block.
func WithConnectionsCountCh(with chan int32) Option {
return func(o *Options) error {
if with == nil {
return errors.New("channel passed to WithConnectionsCountCh is nil")
}
o.WithConnectionsCountCh = with
return nil
}
}

@ -86,4 +86,16 @@ func Test_GetOpts(t *testing.T) {
require.NoError(t, err)
assert.True(opts.WithSkipSessionTeardown)
})
t.Run("with-connections-count-ch", func(t *testing.T) {
assert := assert.New(t)
opts, err := getOpts()
require.NoError(t, err)
assert.Nil(opts.WithConnectionsCountCh)
_, err = getOpts(WithConnectionsCountCh(nil))
require.Error(t, err)
l := make(chan int32)
opts, err = getOpts(WithConnectionsCountCh(l))
require.NoError(t, err)
assert.Equal(l, opts.WithConnectionsCountCh)
})
}

@ -32,25 +32,28 @@ import (
const sessionCancelTimeout = 30 * time.Second
type ClientProxy struct {
tofuToken string
cachedListenerAddress *ua.String
connectionsLeft *atomic.Int32
connsLeftCh chan int32
callerConnectionsLeftCh chan int32
sessionAuthzData *targets.SessionAuthorizationData
createTime time.Time
expiration time.Time
ctx context.Context
cancel context.CancelFunc
transport *http.Transport
workerAddr string
listenAddrPort netip.AddrPort
listener *atomic.Value
listenerCloseOnce *sync.Once
clientTlsConf *tls.Config
connWg *sync.WaitGroup
started *atomic.Bool
skipSessionTeardown bool
tofuToken string
cachedListenerAddress *ua.String
connectionsLeft *atomic.Int32
connectionsCount *atomic.Int32
connsLeftCh chan int32
connsCountCh chan int32
callerConnectionsLeftCh chan int32
callerConnectionsCountCh chan int32
sessionAuthzData *targets.SessionAuthorizationData
createTime time.Time
expiration time.Time
ctx context.Context
cancel context.CancelFunc
transport *http.Transport
workerAddr string
listenAddrPort netip.AddrPort
listener *atomic.Value
listenerCloseOnce *sync.Once
clientTlsConf *tls.Config
connWg *sync.WaitGroup
started *atomic.Bool
skipSessionTeardown bool
}
// New creates a new client proxy. The given context should be cancelable; once
@ -87,16 +90,19 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
}
p := &ClientProxy{
cachedListenerAddress: ua.NewString(""),
connsLeftCh: make(chan int32),
connectionsLeft: new(atomic.Int32),
listener: new(atomic.Value),
listenerCloseOnce: new(sync.Once),
connWg: new(sync.WaitGroup),
listenAddrPort: opts.WithListenAddrPort,
callerConnectionsLeftCh: opts.WithConnectionsLeftCh,
started: new(atomic.Bool),
skipSessionTeardown: opts.WithSkipSessionTeardown,
cachedListenerAddress: ua.NewString(""),
connsLeftCh: make(chan int32),
connsCountCh: make(chan int32),
connectionsLeft: new(atomic.Int32),
connectionsCount: new(atomic.Int32),
listener: new(atomic.Value),
listenerCloseOnce: new(sync.Once),
connWg: new(sync.WaitGroup),
listenAddrPort: opts.WithListenAddrPort,
callerConnectionsLeftCh: opts.WithConnectionsLeftCh,
callerConnectionsCountCh: opts.WithConnectionsCountCh,
started: new(atomic.Bool),
skipSessionTeardown: opts.WithSkipSessionTeardown,
}
if opts.WithListener != nil {
@ -126,6 +132,7 @@ func New(ctx context.Context, authzToken string, opt ...Option) (*ClientProxy, e
}
}
p.connectionsLeft.Store(p.sessionAuthzData.ConnectionLimit)
p.connectionsCount.Store(0)
p.workerAddr = p.sessionAuthzData.WorkerInfo[0].Address
tlsConf, err := p.clientTlsConfig(opt...)
@ -192,8 +199,10 @@ func (p *ClientProxy) Start() (retErr error) {
// Forces the for loop to exit instead of spinning on errors
p.cancel()
p.connectionsLeft.Store(0)
if err := p.listener.Load().(net.Listener).Close(); err != nil && err != net.ErrClosed {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
if err := p.listener.Load().(net.Listener).Close(); err != nil {
if !errors.Is(err, net.ErrClosed) {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
}
}
})
}
@ -231,9 +240,14 @@ func (p *ClientProxy) Start() (retErr error) {
return
}
}
p.connWg.Add(1)
p.connsCountCh <- p.connectionsCount.Add(1)
go func() {
defer listeningConn.Close()
defer func() {
p.connsCountCh <- p.connectionsCount.Add(-1)
}()
defer p.connWg.Done()
wsConn, err := p.getWsConn(p.ctx)
if err != nil {
@ -280,9 +294,22 @@ func (p *ClientProxy) Start() (retErr error) {
if p.callerConnectionsLeftCh != nil {
p.callerConnectionsLeftCh <- connsLeft
}
// If there are no connections left, close the listener
// to stop new connections from being accepted
if connsLeft == 0 {
// Close the listener as we can't authorize any more
// connections
if err := p.listener.Load().(net.Listener).Close(); err != nil {
if !errors.Is(err, net.ErrClosed) {
retErr = errors.Join(retErr, fmt.Errorf("error closing proxy listener: %w", err))
}
}
}
case connsCount := <-p.connsCountCh:
if p.callerConnectionsCountCh != nil {
p.callerConnectionsCountCh <- connsCount
}
// If there are no connections and no connections left,
// we can exit
if p.ConnectionsLeft() == 0 && connsCount == 0 {
return
}
}
@ -388,3 +415,11 @@ func (p *ClientProxy) SessionExpiration() time.Time {
func (p *ClientProxy) ConnectionsLeft() int32 {
return p.connectionsLeft.Load()
}
// ConnectionsCount returns the number of connections in the session.
//
// EXPERIMENTAL: While this API is not expected to change, it is new and
// feedback from users may necessitate changes.
func (p *ClientProxy) ConnectionsCount() int32 {
return p.connectionsCount.Load()
}

@ -45,7 +45,8 @@ type SessionInfo struct {
}
type ConnectionInfo struct {
ConnectionsLeft int32 `json:"connections_left"`
ConnectionsLeft int32 `json:"connections_left"`
ConnectionsCount int32 `json:"connections_count"`
}
type TerminationInfo struct {
@ -420,7 +421,8 @@ func (c *Command) Run(args []string) (retCode int) {
listenAddr = netip.AddrPortFrom(addr, uint16(c.flagListenPort))
connsLeftCh := make(chan int32)
apiProxyOpts := []apiproxy.Option{apiproxy.WithConnectionsLeftCh(connsLeftCh)}
connsCountCh := make(chan int32)
apiProxyOpts := []apiproxy.Option{apiproxy.WithConnectionsLeftCh(connsLeftCh), apiproxy.WithConnectionsCountCh(connsCountCh)}
if listenAddr.IsValid() {
apiProxyOpts = append(apiProxyOpts, apiproxy.WithListenAddrPort(listenAddr))
}
@ -452,8 +454,12 @@ func (c *Command) Run(args []string) (retCode int) {
// done it manually
return
case connsLeft := <-connsLeftCh:
c.updateConnsLeft(connsLeft)
if connsLeft == 0 {
c.updateConnsLeft(connsLeft, clientProxy.ConnectionsCount())
case connsCount := <-connsCountCh:
c.updateConnsLeft(clientProxy.ConnectionsLeft(), connsCount)
// If there are no counts left and there are no connections
// we can exit
if clientProxy.ConnectionsLeft() == 0 && connsCount == 0 {
return
}
}
@ -571,9 +577,10 @@ func (c *Command) printCredentials(creds []*targets.SessionCredential) error {
return nil
}
func (c *Command) updateConnsLeft(connsLeft int32) {
func (c *Command) updateConnsLeft(connsLeft int32, connsCount int32) {
connInfo := ConnectionInfo{
ConnectionsLeft: connsLeft,
ConnectionsLeft: connsLeft,
ConnectionsCount: connsCount,
}
if c.flagExec == "" {

@ -114,7 +114,8 @@ func generateConnectionInfoTableOutput(in ConnectionInfo) string {
var ret []string
nonAttributeMap := map[string]any{
"Connections Left": in.ConnectionsLeft,
"Connections Left": in.ConnectionsLeft,
"Connections Count": in.ConnectionsCount,
}
maxLength := 0

@ -86,9 +86,10 @@ func TestConnectionsLeft(t *testing.T) {
pxyCtx, pxyCancel := context.WithCancel(c1.Context())
defer pxyCancel()
connsLeftCh := make(chan int32)
connsCountCh := make(chan int32)
wg := new(sync.WaitGroup)
pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken, proxy.WithConnectionsLeftCh(connsLeftCh))
pxy, err := proxy.New(pxyCtx, sessAuthz.AuthorizationToken, proxy.WithConnectionsLeftCh(connsLeftCh), proxy.WithConnectionsCountCh(connsCountCh))
require.NoError(err)
wg.Add(1)
go func() {
@ -109,6 +110,7 @@ func TestConnectionsLeft(t *testing.T) {
// While we have sessions left, expect no error and to read and write
// through the proxy, and to read the conns left from the channel. Once we
// have hit the connection limit, we expect an error on dial.
connectionCount := int32(0)
for i := sessionConnsLimit; i >= 0; i-- {
// Give time for the listener to be closed. The information about conns
// left comes from upstream responses and we can circle around too fast
@ -120,6 +122,12 @@ func TestConnectionsLeft(t *testing.T) {
break
}
require.NoError(err)
connsCount := <-connsCountCh
connectionCount++
require.Equal(connectionCount, connsCount)
require.Equal(connectionCount, pxy.ConnectionsCount())
written, err := conn.Write(echo)
require.NoError(err)
require.Equal(written, len(echo))
@ -129,9 +137,13 @@ func TestConnectionsLeft(t *testing.T) {
connsLeft := <-connsLeftCh
require.Equal(i-1, connsLeft)
require.Equal(i-1, pxy.ConnectionsLeft())
}
pxyCancel()
// Wait to ensure cleanup and that the second-start logic works
wg.Wait()
require.Equal(int32(0), pxy.ConnectionsLeft())
require.Equal(int32(0), pxy.ConnectionsCount())
}

Loading…
Cancel
Save