feat(worker): net.Conn implementation to count bytes read and written

This commit introduces a new net.Conn implementation (countingConn) that
will keep track of bytes read and written for a particular client
connection. It is then integrated into the worker code via
interface composition.

Refs: #2501
pull/2575/head
Hugo 4 years ago committed by Hugo Vieira
parent f114e20c6d
commit 1bb80624b8

@ -0,0 +1,48 @@
package worker
import (
"net"
"sync"
)
// countingConn wraps a net.Conn and tallies the number of bytes flowing
// through it in each direction.
type countingConn struct {
	net.Conn
	bytesRead    uint64
	bytesWritten uint64
	// mu guards both counters, since net.Conn methods may be invoked
	// from multiple goroutines concurrently.
	// https://github.com/golang/go/issues/27203#issuecomment-415854958
	mu sync.Mutex
}

// BytesRead returns the total number of bytes read from the connection so far.
func (c *countingConn) BytesRead() uint64 {
	c.mu.Lock()
	n := c.bytesRead
	c.mu.Unlock()
	return n
}

// BytesWritten returns the total number of bytes written to the connection so far.
func (c *countingConn) BytesWritten() uint64 {
	c.mu.Lock()
	n := c.bytesWritten
	c.mu.Unlock()
	return n
}

// Read delegates to the wrapped conn and adds the reported byte count to
// the read total, even when the read returns an error.
func (c *countingConn) Read(in []byte) (int, error) {
	n, err := c.Conn.Read(in)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.bytesRead += uint64(n)
	return n, err
}

// Write delegates to the wrapped conn and adds the reported byte count to
// the written total, even when the write returns an error.
func (c *countingConn) Write(in []byte) (int, error) {
	n, err := c.Conn.Write(in)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.bytesWritten += uint64(n)
	return n, err
}

@ -0,0 +1,198 @@
package worker
import (
"fmt"
"net"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
// TestCountingConn exercises countingConn against a stub net.Conn, covering
// the success path and every error path, and verifies that the byte counters
// advance in each case (including when the underlying call errors).
func TestCountingConn(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name     string
		readErr  bool
		writeErr bool
		closeErr bool
	}{
		{name: "noErrors"},
		{name: "readErr", readErr: true},
		{name: "writeErr", writeErr: true},
		{name: "closeErr", closeErr: true},
		{name: "allErr", readErr: true, writeErr: true, closeErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			under := &testNetConn{
				bytesToRead: 100,
				readErr:     tc.readErr,
				writeErr:    tc.writeErr,
				closeErr:    tc.closeErr,
			}
			payload := []byte("hello")
			conn := countingConn{Conn: under}

			buf := make([]byte, under.bytesToRead)
			read, err := conn.Read(buf)
			require.True(t, under.readCalled)
			if under.readErr {
				require.Error(t, err)
				require.Equal(t, 1, read)
				// Bytes reported alongside an error are still counted.
				require.EqualValues(t, 1, conn.bytesRead)
			} else {
				require.NoError(t, err)
				require.Equal(t, under.bytesToRead, read)
				require.EqualValues(t, under.bytesToRead, conn.bytesRead)
				require.Len(t, buf, under.bytesToRead)
			}

			written, err := conn.Write(payload)
			require.True(t, under.writeCalled)
			if under.writeErr {
				require.Error(t, err)
				require.Equal(t, 1, written)
				// Bytes reported alongside an error are still counted.
				require.EqualValues(t, 1, conn.bytesWritten)
			} else {
				require.NoError(t, err)
				require.Equal(t, len(payload), written)
				require.EqualValues(t, len(payload), conn.bytesWritten)
			}

			err = conn.Close()
			require.True(t, under.closeCalled)
			if under.closeErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestCountingConnConcurrentCalls hammers a countingConn from many goroutines
// at once and then checks the final byte totals. Run with -race to catch
// unsynchronized access to the counters.
func TestCountingConnConcurrentCalls(t *testing.T) {
	t.Parallel()
	bytesToRead := 100
	bytesToWrite := []byte("hello")
	concurrentReads := 1000
	concurrentWrites := 1000
	conn := &countingConn{Conn: &testNetConn{bytesToRead: bytesToRead}}
	wg := sync.WaitGroup{}
	wg.Add(concurrentReads)
	for i := 0; i < concurrentReads; i++ {
		go func() {
			defer wg.Done()
			in := make([]byte, bytesToRead)
			read, err := conn.Read(in)
			// Don't use require here: FailNow must only be called from the
			// goroutine running the test, so report failures with t.Errorf,
			// which is safe to call from multiple goroutines.
			if err != nil {
				t.Errorf("read: unexpected error: %v", err)
			}
			if read != bytesToRead {
				t.Errorf("read %d bytes, want %d", read, bytesToRead)
			}
		}()
	}
	wg.Add(concurrentWrites)
	for i := 0; i < concurrentWrites; i++ {
		go func() {
			defer wg.Done()
			written, err := conn.Write(bytesToWrite)
			if err != nil {
				t.Errorf("write: unexpected error: %v", err)
			}
			if written != len(bytesToWrite) {
				t.Errorf("wrote %d bytes, want %d", written, len(bytesToWrite))
			}
		}()
	}
	wg.Wait()
	require.EqualValues(t, concurrentReads*bytesToRead, conn.BytesRead())
	require.EqualValues(t, concurrentWrites*len(bytesToWrite), conn.BytesWritten())
}
// testNetConn is a stub net.Conn used to exercise countingConn. It embeds
// net.Conn only to satisfy the interface; just Read, Write, and Close are
// implemented.
type testNetConn struct {
	net.Conn // So we don't have to implement the entire interface.
	// Behavior knobs set by the test.
	bytesToRead int
	readErr     bool
	writeErr    bool
	closeErr    bool
	// Call tracking inspected by the test.
	readCalled  bool
	writeCalled bool
	closeCalled bool
}

// Read records the call, then either fails (reporting 1 byte consumed) or
// fills the first bytesToRead bytes of in with the value 10.
func (c *testNetConn) Read(in []byte) (int, error) {
	c.readCalled = true
	if c.readErr {
		return 1, fmt.Errorf("oops, read error")
	}
	for i := range in[:c.bytesToRead] {
		in[i] = 10
	}
	return c.bytesToRead, nil
}

// Write records the call and either fails (reporting 1 byte written) or
// pretends the whole payload was written.
func (c *testNetConn) Write(in []byte) (int, error) {
	c.writeCalled = true
	if c.writeErr {
		return 1, fmt.Errorf("oops, write error")
	}
	return len(in), nil
}

// Close records the call and optionally fails.
func (c *testNetConn) Close() error {
	c.closeCalled = true
	if c.closeErr {
		return fmt.Errorf("oops, close error")
	}
	return nil
}

@ -218,6 +218,8 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
return
}
event.WriteSysEvent(ctx, op, "connection successfully authorized", "session_id", sessionId, "connection_id", ci.Id)
cc := &countingConn{Conn: websocket.NetConn(connCtx, conn, websocket.MessageBinary)}
defer func() {
if sessionManager.RequestCloseConnections(ctx, map[string]string{ci.Id: sess.GetId()}) {
event.WriteSysEvent(ctx, op, "connection closed", "session_id", sessionId, "connection_id", ci.Id)
@ -240,7 +242,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
conf := proxyHandlers.Config{
UserClientIp: net.ParseIP(userClientIp),
ClientAddress: clientAddr,
ClientConn: conn,
ClientConn: cc,
RemoteEndpoint: sess.GetEndpoint(),
Session: sess,
ConnectionId: ci.Id,

@ -7,7 +7,6 @@ import (
"sync"
"github.com/hashicorp/boundary/internal/daemon/worker/session"
"nhooyr.io/websocket"
)
// Config provides the core parameters needed for a worker to create a proxy between
@ -20,7 +19,7 @@ type Config struct {
// there are any load balancers or proxies between the user and the worker,
// then it will be the address of the last one before the worker.
ClientAddress *net.TCPAddr
ClientConn *websocket.Conn
ClientConn net.Conn
RemoteEndpoint string
Session session.Session

@ -35,7 +35,7 @@ func TestConfigValidate(t *testing.T) {
{
name: "missing-client-address",
conf: Config{
ClientConn: conn,
ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
RemoteEndpoint: "tcp://remote",
Session: si,
ConnectionId: "connection-id",
@ -58,7 +58,7 @@ func TestConfigValidate(t *testing.T) {
name: "missing-remote-endpoint",
conf: Config{
ClientAddress: clientAddr,
ClientConn: conn,
ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
Session: si,
ConnectionId: "connection-id",
},
@ -69,7 +69,7 @@ func TestConfigValidate(t *testing.T) {
name: "missing-session",
conf: Config{
ClientAddress: clientAddr,
ClientConn: conn,
ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
RemoteEndpoint: "tcp://remote",
ConnectionId: "connection-id",
},
@ -80,7 +80,7 @@ func TestConfigValidate(t *testing.T) {
name: "missing-connection-id",
conf: Config{
ClientAddress: clientAddr,
ClientConn: conn,
ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
RemoteEndpoint: "tcp://remote",
Session: si,
},
@ -91,7 +91,7 @@ func TestConfigValidate(t *testing.T) {
name: "valid",
conf: Config{
ClientAddress: clientAddr,
ClientConn: conn,
ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary),
RemoteEndpoint: "tcp://remote",
Session: si,
ConnectionId: "connection-id",

@ -10,7 +10,6 @@ import (
"github.com/hashicorp/boundary/internal/daemon/worker/proxy"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"nhooyr.io/websocket"
)
func init() {
@ -59,22 +58,19 @@ func handleProxy(ctx context.Context, conf proxy.Config, _ ...proxy.Option) erro
return fmt.Errorf("error marking connection as connected: %w", err)
}
// Get a wrapped net.Conn so we can use io.Copy
netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
connWg := new(sync.WaitGroup)
connWg.Add(2)
go func() {
defer connWg.Done()
_, _ = io.Copy(netConn, tcpRemoteConn)
_ = netConn.Close()
_, _ = io.Copy(conn, tcpRemoteConn)
_ = conn.Close()
_ = tcpRemoteConn.Close()
}()
go func() {
defer connWg.Done()
_, _ = io.Copy(tcpRemoteConn, netConn)
_, _ = io.Copy(tcpRemoteConn, conn)
_ = tcpRemoteConn.Close()
_ = netConn.Close()
_ = conn.Close()
}()
connWg.Wait()
return nil

@ -85,9 +85,10 @@ func TestHandleTcpProxyV1(t *testing.T) {
_, _, err = s.RequestAuthorizeConnection(ctx, "workerid", connCancelFn)
require.NoError(err)
conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary)
conf := proxy.Config{
ClientAddress: clientAddr,
ClientConn: proxyConn,
ClientConn: conn,
RemoteEndpoint: fmt.Sprintf("tcp://localhost:%d", port),
Session: s,
ConnectionId: "mock-connection",

Loading…
Cancel
Save