Pass both the control and data context into handle

pull/3251/head
Todd 3 years ago committed by Timothy Messier
parent 8c96d3ef9b
commit fb1ea92a30
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -294,7 +294,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function")
event.WriteError(ctx, op, err)
}
runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager)
runProxy, err := handleProxyFn(ctx, ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderManager)
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying")
event.WriteError(ctx, op, err)
@ -321,7 +321,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
return
}
runProxy(ctx)
runProxy()
}, nil
}

@ -39,14 +39,14 @@ type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error
// ProxyConnFn is called after the call to ConnectConnection on the cluster.
// ProxyConnFn blocks until the specific request that is being proxied is finished
type ProxyConnFn func(ctx context.Context)
type ProxyConnFn func()
// Handler is the type that all proxies need to implement to be called by the worker
// when a new client connection is created. If there is an error ProxyConnFn must
// be nil. If there is no error ProxyConnFn must be set. When Handler has
// returned, it is expected that the initial connection to the endpoint has been
// established.
type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error)
type Handler func(controlCtx context.Context, dataCtx context.Context, df DecryptFn, c net.Conn, pd *ProxyDialer, connId string, pb *anypb.Any, rm RecordingManager) (ProxyConnFn, error)
func RegisterHandler(protocol string, handler Handler) error {
_, loaded := handlers.LoadOrStore(protocol, handler)

@ -4,7 +4,6 @@
package proxy
import (
"context"
"net/http"
"sync/atomic"
)
@ -35,9 +34,9 @@ func ProxyHandlerCounter(h http.Handler) http.Handler {
// proxyConnFnCounter wraps a ProxyState and keeps the proxyCount incremented
// while it runs.
func proxyConnFnCounter(fn ProxyConnFn) ProxyConnFn {
return func(ctx context.Context) {
return func() {
ProxyState.proxyCount.Add(1)
defer ProxyState.proxyCount.Add(-1)
fn(ctx)
fn()
}
}

@ -4,7 +4,6 @@
package proxy
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -22,9 +21,9 @@ func TestProxyStateHelpers(t *testing.T) {
t.Run("proxyConnFnCounter", func(t *testing.T) {
assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections())
proxyConnFnCounter(func(context.Context) {
proxyConnFnCounter(func() {
assert.EqualValues(t, 1, ProxyState.CurrentProxiedConnections())
})(context.Background())
})()
assert.EqualValues(t, 0, ProxyState.CurrentProxiedConnections())
})

@ -17,7 +17,7 @@ import (
func TestRegisterHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers
@ -40,7 +40,7 @@ func TestRegisterHandler(t *testing.T) {
func TestAlwaysTcpGetHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
fn := func(context.Context, context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers

@ -27,22 +27,22 @@ func init() {
// handleProxy returns a ProxyConnFn which starts the copy between the
// connections and blocks until an error (EOF on happy path) is received on
// either connection.
func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) {
func handleProxy(controlCtx context.Context, _ context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) {
const op = "tcp.HandleProxy"
switch {
case conn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "conn is nil")
return nil, errors.New(controlCtx, errors.InvalidParameter, op, "conn is nil")
case out == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "proxy dialer is nil")
return nil, errors.New(controlCtx, errors.InvalidParameter, op, "proxy dialer is nil")
case len(connId) == 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "connection id is empty")
return nil, errors.New(controlCtx, errors.InvalidParameter, op, "connection id is empty")
}
remoteConn, err := out.Dial(ctx)
remoteConn, err := out.Dial(controlCtx)
if err != nil {
return nil, err
}
return func(ctx context.Context) {
return func() {
connWg := new(sync.WaitGroup)
connWg.Add(2)
go func() {

@ -88,7 +88,8 @@ func TestHandleProxy_Errors(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil)
ctx := context.Background()
fn, err := handleProxy(ctx, ctx, nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil)
if tc.wantError {
assert.Error(t, err)
assert.Nil(t, fn)
@ -166,14 +167,14 @@ func TestHandleTcpProxyV1(t *testing.T) {
conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary)
go func() {
fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil)
fn, err := handleProxy(ctx, context.Background(), nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil)
t.Cleanup(func() {
// Use of the t.Cleanup is so we can check the state of the returned
// error since it isn't valid to call `t.FailNow()` from a goroutine.
// https://pkg.go.dev/testing#T.FailNow
require.NoError(err)
})
fn(context.Background())
fn()
}()
// wait for HandleTcpProxyV1 to dial endpoint

Loading…
Cancel
Save