Add the first iteration of the recorder cache and wire it in

pull/3251/head
Todd 3 years ago committed by Timothy Messier
parent 397a697731
commit bd2bca987d
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -294,7 +294,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "error getting decryption function")
event.WriteError(ctx, op, err)
}
runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx)
runProxy, err := handleProxyFn(ctx, decryptFn, cc, pDialer, acResp.GetConnectionId(), protocolCtx, w.recorderCache)
if err != nil {
conn.Close(proxyHandlers.WebsocketStatusProtocolSetupError, "unable to setup proxying")
event.WriteError(ctx, op, err)

@ -31,6 +31,9 @@ var (
GetHandler = tcpOnly
)
// RecordingManager allows a handler for a protocol that supports recording.
type RecordingManager any
// DecryptFn decrypts the provided bytes into a proto.Message
type DecryptFn func(ctx context.Context, from []byte, to proto.Message) error
@ -43,7 +46,7 @@ type ProxyConnFn func(ctx context.Context)
// be nil. If there is no error ProxyConnFn must be set. When Handler has
// returned, it is expected that the initial connection to the endpoint has been
// established.
type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error)
type Handler func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error)
func RegisterHandler(protocol string, handler Handler) error {
_, loaded := handlers.LoadOrStore(protocol, handler)

@ -17,7 +17,7 @@ import (
func TestRegisterHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers
@ -40,7 +40,7 @@ func TestRegisterHandler(t *testing.T) {
func TestAlwaysTcpGetHandler(t *testing.T) {
assert, require := assert.New(t), require.New(t)
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any) (ProxyConnFn, error) {
fn := func(context.Context, DecryptFn, net.Conn, *ProxyDialer, string, *anypb.Any, RecordingManager) (ProxyConnFn, error) {
return nil, nil
}
oldHandler := handlers

@ -27,7 +27,7 @@ func init() {
// handleProxy returns a ProxyConnFn which starts the copy between the
// connections and blocks until an error (EOF on happy path) is received on
// either connection.
func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any) (proxy.ProxyConnFn, error) {
func handleProxy(ctx context.Context, _ proxy.DecryptFn, conn net.Conn, out *proxy.ProxyDialer, connId string, _ *anypb.Any, _ proxy.RecordingManager) (proxy.ProxyConnFn, error) {
const op = "tcp.HandleProxy"
switch {
case conn == nil:

@ -88,7 +88,7 @@ func TestHandleProxy_Errors(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx)
fn, err := handleProxy(context.Background(), nil, tc.conn, tc.dialer, tc.connId, tc.protocolCtx, nil)
if tc.wantError {
assert.Error(t, err)
assert.Nil(t, fn)
@ -166,7 +166,7 @@ func TestHandleTcpProxyV1(t *testing.T) {
conn := websocket.NetConn(ctx, proxyConn, websocket.MessageBinary)
go func() {
fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext())
fn, err := handleProxy(ctx, nil, conn, tDial, resp.GetConnectionId(), resp.GetProtocolContext(), nil)
t.Cleanup(func() {
// Use of the t.Cleanup is so we can check the state of the returned
// error since it isn't valid to call `t.FailNow()` from a goroutine.

@ -69,10 +69,16 @@ type downstreamers interface {
RootId() string
}
// recorderCache updates the status updates with relevant recording
// information
type recorderCache any
// reverseConnReceiverFactory provides a simple factory which a Worker can use to
// create its reverseConnReceiver
var reverseConnReceiverFactory func() reverseConnReceiver
var recorderCacheFactory func() recorderCache
var initializeReverseGrpcClientCollectors = noopInitializePromCollectors
func noopInitializePromCollectors(r prometheus.Registerer) {}
@ -103,6 +109,8 @@ type Worker struct {
sessionManager session.Manager
recorderCache recorderCache
controllerStatusConn *atomic.Value
everAuthenticated *ua.Uint32
lastStatusSuccess *atomic.Value
@ -182,6 +190,10 @@ func New(conf *Config) (*Worker, error) {
w.downstreamReceiver = reverseConnReceiverFactory()
}
if recorderCacheFactory != nil {
w.recorderCache = recorderCacheFactory()
}
w.lastStatusSuccess.Store((*LastStatusInformation)(nil))
scheme := strconv.FormatInt(time.Now().UnixNano(), 36)
controllerResolver := manual.NewBuilderWithScheme(scheme)

@ -84,6 +84,7 @@ const (
// client and server error codes
Unauthorized Code = 401 // Unauthorized represents the operation is unauthorized
Forbidden Code = 403 // Forbidden represents the operation is forbidden
NotFound Code = 404 // NotFound represents an operation which is unable to find the requested item.
Conflict Code = 409 // Conflict represents the operation failed due to failed pre-condition or was aborted.
Internal Code = 500 // InternalError represents the system encountered an unexpected condition.

Loading…
Cancel
Save