diff --git a/internal/daemon/worker/countingconn.go b/internal/daemon/worker/countingconn.go new file mode 100644 index 0000000000..a8c7b9bd06 --- /dev/null +++ b/internal/daemon/worker/countingconn.go @@ -0,0 +1,48 @@ +package worker + +import ( + "net" + "sync" +) + +type countingConn struct { + net.Conn + + bytesRead uint64 + bytesWritten uint64 + // Use mutex for counters as net.Conn methods may be called concurrently + // https://github.com/golang/go/issues/27203#issuecomment-415854958 + mu sync.Mutex +} + +// BytesRead reports the number of bytes read so far +func (c *countingConn) BytesRead() uint64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.bytesRead +} + +// BytesWritten reports the number of bytes written so far +func (c *countingConn) BytesWritten() uint64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.bytesWritten +} + +// Read wraps the embedded conn's Read and counts the number of bytes read. +func (c *countingConn) Read(in []byte) (int, error) { + n, err := c.Conn.Read(in) + c.mu.Lock() + c.bytesRead += uint64(n) + c.mu.Unlock() + return n, err +} + +// Write wraps the embedded conn's Write and counts the number of bytes read. +func (c *countingConn) Write(in []byte) (int, error) { + n, err := c.Conn.Write(in) + c.mu.Lock() + c.bytesWritten += uint64(n) + c.mu.Unlock() + return n, err +} diff --git a/internal/daemon/worker/countingconn_test.go b/internal/daemon/worker/countingconn_test.go new file mode 100644 index 0000000000..6e4f02b3fb --- /dev/null +++ b/internal/daemon/worker/countingconn_test.go @@ -0,0 +1,198 @@ +package worker + +import ( + "fmt" + "net" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCountingConn(t *testing.T) { + t.Parallel() + tests := []struct { + name string + writeBytes []byte + underlyingConn *testNetConn + }{ + { + name: "noErrors", + writeBytes: []byte("hello"), + underlyingConn: &testNetConn{ + bytesToRead: 100, + readErr: false, + writeErr: false, + closeErr: false, + }, + }, + { + name: "readErr", + writeBytes: []byte("hello"), + underlyingConn: &testNetConn{ + bytesToRead: 100, + readErr: true, + writeErr: false, + closeErr: false, + }, + }, + { + name: "writeErr", + writeBytes: []byte("hello"), + underlyingConn: &testNetConn{ + bytesToRead: 100, + readErr: false, + writeErr: true, + closeErr: false, + }, + }, + { + name: "closeErr", + writeBytes: []byte("hello"), + underlyingConn: &testNetConn{ + bytesToRead: 100, + readErr: false, + writeErr: false, + closeErr: true, + }, + }, + { + name: "allErr", + writeBytes: []byte("hello"), + underlyingConn: &testNetConn{ + bytesToRead: 100, + readErr: true, + writeErr: true, + closeErr: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn := countingConn{Conn: tt.underlyingConn} + + readBytes := make([]byte, tt.underlyingConn.bytesToRead) + read, err := conn.Read(readBytes) + require.True(t, tt.underlyingConn.readCalled) + if tt.underlyingConn.readErr { + require.Error(t, err) + require.Equal(t, 1, read) + require.EqualValues(t, 1, conn.bytesRead) // We still capture the bytes on error + } else { + require.NoError(t, err) + require.Equal(t, tt.underlyingConn.bytesToRead, read) + require.EqualValues(t, tt.underlyingConn.bytesToRead, conn.bytesRead) + require.Len(t, readBytes, tt.underlyingConn.bytesToRead) + } + + written, err := conn.Write(tt.writeBytes) + require.True(t, tt.underlyingConn.writeCalled) + if tt.underlyingConn.writeErr { + require.Error(t, err) + require.Equal(t, 1, written) + require.EqualValues(t, 1, conn.bytesWritten) // We still capture the bytes on error + } else { + require.NoError(t, err) + require.Equal(t, len(tt.writeBytes), written) + require.EqualValues(t, len(tt.writeBytes), conn.bytesWritten) + } + + err = conn.Close() + require.True(t, tt.underlyingConn.closeCalled) + if tt.underlyingConn.closeErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCountingConnConcurrentCalls(t *testing.T) { + t.Parallel() + + bytesToRead := 100 + bytesToWrite := []byte("hello") + + concurrentReads := 1000 + concurrentWrites := 1000 + + conn := &countingConn{Conn: &testNetConn{bytesToRead: bytesToRead}} + + wg := sync.WaitGroup{} + wg.Add(concurrentReads) + for i := 0; i < concurrentReads; i++ { + go func() { + defer wg.Done() + in := make([]byte, bytesToRead) + + read, err := conn.Read(in) + require.NoError(t, err) + require.Equal(t, bytesToRead, read) + require.Len(t, in, bytesToRead) + }() + } + + wg.Add(concurrentWrites) + for i := 0; i < concurrentWrites; i++ { + go func() { + defer wg.Done() + + written, err := conn.Write(bytesToWrite) + require.NoError(t, err) + require.Equal(t, len(bytesToWrite), written) + }() + } + + wg.Wait() + + require.EqualValues(t, concurrentReads*bytesToRead, conn.BytesRead()) + require.EqualValues(t, concurrentWrites*len(bytesToWrite), conn.BytesWritten()) +} + +type testNetConn struct { + net.Conn // So we don't have to implement the entire interface. + + // Test properties + bytesToRead int + readErr bool + writeErr bool + closeErr bool + + // Test results + readCalled bool + writeCalled bool + closeCalled bool +} + +func (t *testNetConn) Read(in []byte) (int, error) { + t.readCalled = true + if t.readErr { + return 1, fmt.Errorf("oops, read error") + } + + for i := 0; i < t.bytesToRead; i++ { + in[i] = 10 + } + + return int(t.bytesToRead), nil +} + +func (t *testNetConn) Write(in []byte) (int, error) { + t.writeCalled = true + if t.writeErr { + return 1, fmt.Errorf("oops, write error") + } + + return len(in), nil +} + +func (t *testNetConn) Close() error { + t.closeCalled = true + if t.closeErr { + return fmt.Errorf("oops, close error") + } + + return nil +} diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 727531579d..7614a447d7 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -218,6 +218,8 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa return } event.WriteSysEvent(ctx, op, "connection successfully authorized", "session_id", sessionId, "connection_id", ci.Id) + + cc := &countingConn{Conn: websocket.NetConn(connCtx, conn, websocket.MessageBinary)} defer func() { if sessionManager.RequestCloseConnections(ctx, map[string]string{ci.Id: sess.GetId()}) { event.WriteSysEvent(ctx, op, "connection closed", "session_id", sessionId, "connection_id", ci.Id) @@ -240,7 +242,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa conf := proxyHandlers.Config{ UserClientIp: net.ParseIP(userClientIp), ClientAddress: clientAddr, - ClientConn: conn, + ClientConn: cc, RemoteEndpoint: sess.GetEndpoint(), Session: sess, ConnectionId: ci.Id, diff --git a/internal/daemon/worker/proxy/proxy.go b/internal/daemon/worker/proxy/proxy.go index 719cfc98c0..34190ea0f7 100644 --- a/internal/daemon/worker/proxy/proxy.go +++ b/internal/daemon/worker/proxy/proxy.go @@ -7,7 +7,6 @@ import ( "sync" "github.com/hashicorp/boundary/internal/daemon/worker/session" - "nhooyr.io/websocket" ) // Config provides the core parameters needed for a worker to create a proxy between @@ -20,7 +19,7 @@ type Config struct { // there are any load balancers or proxies between the user and the worker, // then it will be the address of the last one before the worker. ClientAddress *net.TCPAddr - ClientConn *websocket.Conn + ClientConn net.Conn RemoteEndpoint string Session session.Session diff --git a/internal/daemon/worker/proxy/proxy_test.go b/internal/daemon/worker/proxy/proxy_test.go index 3be549226a..3bec9e5841 100644 --- a/internal/daemon/worker/proxy/proxy_test.go +++ b/internal/daemon/worker/proxy/proxy_test.go @@ -35,7 +35,7 @@ func TestConfigValidate(t *testing.T) { { name: "missing-client-address", conf: Config{ - ClientConn: conn, + ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary), RemoteEndpoint: "tcp://remote", Session: si, ConnectionId: "connection-id", @@ -58,7 +58,7 @@ func TestConfigValidate(t *testing.T) { name: "missing-remote-endpoint", conf: Config{ ClientAddress: clientAddr, - ClientConn: conn, + ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary), Session: si, ConnectionId: "connection-id", }, @@ -69,7 +69,7 @@ func TestConfigValidate(t *testing.T) { name: "missing-session", conf: Config{ ClientAddress: clientAddr, - ClientConn: conn, + ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary), RemoteEndpoint: "tcp://remote", ConnectionId: "connection-id", }, @@ -80,7 +80,7 @@ func TestConfigValidate(t *testing.T) { name: "missing-connection-id", conf: Config{ ClientAddress: clientAddr, - ClientConn: conn, + ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary), RemoteEndpoint: "tcp://remote", Session: si, }, @@ -91,7 +91,7 @@ func TestConfigValidate(t *testing.T) { name: "valid", conf: Config{ ClientAddress: clientAddr, - ClientConn: conn, + ClientConn: websocket.NetConn(context.Background(), conn, websocket.MessageBinary), RemoteEndpoint: "tcp://remote", Session: si, ConnectionId: "connection-id", diff --git a/internal/daemon/worker/proxy/tcp/tcp.go b/internal/daemon/worker/proxy/tcp/tcp.go index ba9c60de39..f3ca66bba6 100644 --- a/internal/daemon/worker/proxy/tcp/tcp.go +++ b/internal/daemon/worker/proxy/tcp/tcp.go @@ -10,7 +10,6 @@ import ( "github.com/hashicorp/boundary/internal/daemon/worker/proxy" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" - "nhooyr.io/websocket" ) func init() { @@ -59,22 +58,19 @@ func handleProxy(ctx context.Context, conf proxy.Config, _ ...proxy.Option) erro return fmt.Errorf("error marking connection as connected: %w", err) } - // Get a wrapped net.Conn so we can use io.Copy - netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) - connWg := new(sync.WaitGroup) connWg.Add(2) go func() { defer connWg.Done() - _, _ = io.Copy(netConn, tcpRemoteConn) - _ = netConn.Close() + _, _ = io.Copy(conn, tcpRemoteConn) + _ = conn.Close() _ = tcpRemoteConn.Close() }() go func() { defer connWg.Done() - _, _ = io.Copy(tcpRemoteConn, netConn) + _, _ = io.Copy(tcpRemoteConn, conn) _ = tcpRemoteConn.Close() - _ = netConn.Close() + _ = conn.Close() }() connWg.Wait() return nil diff --git a/internal/daemon/worker/proxy/tcp/tcp_test.go b/internal/daemon/worker/proxy/tcp/tcp_test.go index 0f8cf9f05b..978461ab19 100644 --- a/internal/daemon/worker/proxy/tcp/tcp_test.go +++ b/internal/daemon/worker/proxy/tcp/tcp_test.go @@ -85,9 +85,10 @@ func TestHandleTcpProxyV1(t *testing.T) { _, _, err = s.RequestAuthorizeConnection(ctx, "workerid", connCancelFn) require.NoError(err) + conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary) conf := proxy.Config{ ClientAddress: clientAddr, - ClientConn: proxyConn, + ClientConn: conn, RemoteEndpoint: fmt.Sprintf("tcp://localhost:%d", port), Session: s, ConnectionId: "mock-connection",