mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
491 lines
14 KiB
491 lines
14 KiB
package helper
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/api/targets"
|
|
"github.com/hashicorp/boundary/globals"
|
|
targetspb "github.com/hashicorp/boundary/internal/gen/controller/api/resources/targets"
|
|
"github.com/hashicorp/boundary/internal/proxy"
|
|
"github.com/hashicorp/boundary/internal/servers/controller/common"
|
|
"github.com/hashicorp/boundary/internal/servers/worker"
|
|
"github.com/hashicorp/boundary/internal/session"
|
|
"github.com/hashicorp/go-cleanhttp"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-secure-stdlib/base62"
|
|
"github.com/mr-tron/base58"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/proto"
|
|
"nhooyr.io/websocket"
|
|
"nhooyr.io/websocket/wspb"
|
|
)
|
|
|
|
const (
|
|
DefaultGracePeriod = time.Second * 15
|
|
expectConnectionStateOnControllerTimeout = time.Minute * 2
|
|
expectConnectionStateOnWorkerTimeout = DefaultGracePeriod * 2
|
|
|
|
// This is the interval that we check states on in the worker. It
|
|
// needs to be particularly granular to ensure that we allow for
|
|
// adequate time to catch any edge cases where the connection state
|
|
// disappears from the worker before we can update the state.
|
|
//
|
|
// As of this writing, the (hardcoded) status request interval on
|
|
// the worker is 2 seconds, and a session is removed from the state
|
|
// when it has no connections left on the *next* pass. A one second
|
|
// interval would only mean two chances to check with a high
|
|
// possibility of skew; while it seems okay I'm still not 100%
|
|
// comfortable with this little resolution.
|
|
expectConnectionStateOnWorkerInterval = time.Millisecond * 100
|
|
)
|
|
|
|
// TestSession represents an authorized session.
|
|
type TestSession struct {
|
|
sessionId string
|
|
workerAddr string
|
|
transport *http.Transport
|
|
tofuToken string
|
|
connectionsLeft int32
|
|
logger hclog.Logger
|
|
}
|
|
|
|
// NewTestSession authorizes a session and creates all of the data
|
|
// necessary to initialize.
|
|
func NewTestSession(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
logger hclog.Logger,
|
|
tcl *targets.Client,
|
|
targetId string,
|
|
) *TestSession {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
sar, err := tcl.AuthorizeSession(ctx, "ttcp_1234567890")
|
|
require.NoError(err)
|
|
require.NotNil(sar)
|
|
|
|
s := &TestSession{
|
|
sessionId: sar.Item.SessionId,
|
|
logger: logger,
|
|
}
|
|
authzString := sar.GetItem().(*targets.SessionAuthorization).AuthorizationToken
|
|
marshaled, err := base58.FastBase58Decoding(authzString)
|
|
require.NoError(err)
|
|
require.NotZero(marshaled)
|
|
|
|
sessionAuthzData := new(targetspb.SessionAuthorizationData)
|
|
err = proto.Unmarshal(marshaled, sessionAuthzData)
|
|
require.NoError(err)
|
|
require.NotZero(sessionAuthzData.GetWorkerInfo())
|
|
|
|
s.workerAddr = sessionAuthzData.GetWorkerInfo()[0].GetAddress()
|
|
|
|
parsedCert, err := x509.ParseCertificate(sessionAuthzData.Certificate)
|
|
require.NoError(err)
|
|
require.Len(parsedCert.DNSNames, 1)
|
|
|
|
certPool := x509.NewCertPool()
|
|
certPool.AddCert(parsedCert)
|
|
tlsConf := &tls.Config{
|
|
Certificates: []tls.Certificate{
|
|
{
|
|
Certificate: [][]byte{sessionAuthzData.Certificate},
|
|
PrivateKey: ed25519.PrivateKey(sessionAuthzData.PrivateKey),
|
|
Leaf: parsedCert,
|
|
},
|
|
},
|
|
RootCAs: certPool,
|
|
ServerName: parsedCert.DNSNames[0],
|
|
MinVersion: tls.VersionTLS13,
|
|
}
|
|
|
|
s.transport = cleanhttp.DefaultTransport()
|
|
s.transport.DisableKeepAlives = false
|
|
s.transport.TLSClientConfig = tlsConf
|
|
s.transport.IdleConnTimeout = 0
|
|
|
|
return s
|
|
}
|
|
|
|
// connect returns a connected websocket for the stored session,
|
|
// connecting to the stored workerAddr with the configured transport.
|
|
//
|
|
// The returned (wrapped) net.Conn should be ready for communication.
|
|
func (s *TestSession) connect(ctx context.Context, t *testing.T) net.Conn {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
conn, resp, err := websocket.Dial(
|
|
ctx,
|
|
fmt.Sprintf("wss://%s/v1/proxy", s.workerAddr),
|
|
&websocket.DialOptions{
|
|
HTTPClient: &http.Client{
|
|
Transport: s.transport,
|
|
},
|
|
Subprotocols: []string{globals.TcpProxyV1},
|
|
},
|
|
)
|
|
require.NoError(err)
|
|
require.NotNil(conn)
|
|
require.NotNil(resp)
|
|
require.Equal(resp.Header.Get("Sec-WebSocket-Protocol"), globals.TcpProxyV1)
|
|
|
|
// Send the handshake.
|
|
if s.tofuToken == "" {
|
|
s.tofuToken, err = base62.Random(20)
|
|
}
|
|
|
|
require.NoError(err)
|
|
handshake := proxy.ClientHandshake{TofuToken: s.tofuToken}
|
|
err = wspb.Write(ctx, conn, &handshake)
|
|
require.NoError(err)
|
|
|
|
// Receive/check the handshake
|
|
var handshakeResult proxy.HandshakeResult
|
|
err = wspb.Read(ctx, conn, &handshakeResult)
|
|
require.NoError(err)
|
|
|
|
// This is just a cursory check to make sure that the handshake is
|
|
// populated. We could check connections remaining too, but that
|
|
// could legitimately be a trivial (zero) value.
|
|
require.NotNil(handshakeResult.GetExpiration())
|
|
s.connectionsLeft = handshakeResult.GetConnectionsLeft()
|
|
|
|
return websocket.NetConn(ctx, conn, websocket.MessageBinary)
|
|
}
|
|
|
|
// ExpectConnectionStateOnController waits until all connections in a
|
|
// session have transitioned to a particular state on the controller.
|
|
func (s *TestSession) ExpectConnectionStateOnController(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
sessionRepoFn common.SessionRepoFactory,
|
|
expectState session.ConnectionStatus,
|
|
) {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
assert := assert.New(t)
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, expectConnectionStateOnControllerTimeout)
|
|
defer cancel()
|
|
|
|
// This is just for initialization of the actual state set.
|
|
const sessionStatusUnknown session.ConnectionStatus = "unknown"
|
|
|
|
sessionRepo, err := sessionRepoFn()
|
|
require.NoError(err)
|
|
|
|
conns, err := sessionRepo.ListConnectionsBySessionId(ctx, s.sessionId)
|
|
require.NoError(err)
|
|
// To avoid misleading passing tests, we require this test be used
|
|
// with sessions with connections..
|
|
require.Greater(len(conns), 0, "should have at least one connection")
|
|
|
|
// Make a set of states, 1 per connection
|
|
actualStates := make([]session.ConnectionStatus, len(conns))
|
|
for i := range actualStates {
|
|
actualStates[i] = sessionStatusUnknown
|
|
}
|
|
|
|
// Make expect set for comparison
|
|
expectStates := make([]session.ConnectionStatus, len(conns))
|
|
for i := range expectStates {
|
|
expectStates[i] = expectState
|
|
}
|
|
|
|
for {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
|
|
for i, conn := range conns {
|
|
_, states, err := sessionRepo.LookupConnection(ctx, conn.PublicId, nil)
|
|
require.NoError(err)
|
|
// Look at the first state in the returned list, which will
|
|
// be the most recent state.
|
|
actualStates[i] = states[0].Status
|
|
}
|
|
|
|
if reflect.DeepEqual(expectStates, actualStates) {
|
|
break
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
// "non-fatal" assert here, so that we can surface both timeouts
|
|
// and invalid state
|
|
assert.NoError(ctx.Err())
|
|
|
|
// Assert
|
|
require.Equal(expectStates, actualStates)
|
|
s.logger.Debug("successfully asserted all connection states on controller", "expected_states", expectStates, "actual_states", actualStates)
|
|
}
|
|
|
|
// ExpectConnectionStateOnWorker waits until all connections in a
|
|
// session have transitioned to a particular state on the worker.
|
|
func (s *TestSession) ExpectConnectionStateOnWorker(
|
|
ctx context.Context,
|
|
t *testing.T,
|
|
tw *worker.TestWorker,
|
|
expectState session.ConnectionStatus,
|
|
) {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
assert := assert.New(t)
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, expectConnectionStateOnWorkerTimeout)
|
|
defer cancel()
|
|
|
|
// This is just for initialization of the actual state set.
|
|
const sessionStatusUnknown session.ConnectionStatus = "unknown"
|
|
|
|
connIds := s.testWorkerConnectionIds(t, tw)
|
|
// Make a set of states, 1 per connection
|
|
actualStates := make(map[string]session.ConnectionStatus, len(connIds))
|
|
for _, id := range connIds {
|
|
actualStates[id] = sessionStatusUnknown
|
|
}
|
|
|
|
// Make expect set for comparison
|
|
expectStates := make(map[string]session.ConnectionStatus, len(connIds))
|
|
for _, id := range connIds {
|
|
expectStates[id] = expectState
|
|
}
|
|
|
|
for {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
|
|
for _, conn := range s.testWorkerConnectionInfo(t, tw) {
|
|
actualStates[conn.Id] = session.ConnectionStatusFromProtoVal(conn.Status)
|
|
}
|
|
|
|
if reflect.DeepEqual(expectStates, actualStates) {
|
|
break
|
|
}
|
|
|
|
time.Sleep(expectConnectionStateOnWorkerInterval)
|
|
}
|
|
|
|
// "non-fatal" assert here, so that we can surface both timeouts
|
|
// and invalid state
|
|
assert.NoError(ctx.Err())
|
|
|
|
// Assert
|
|
require.Equal(expectStates, actualStates)
|
|
s.logger.Debug("successfully asserted all connection states on worker", "expected_states", expectStates, "actual_states", actualStates)
|
|
}
|
|
|
|
func (s *TestSession) testWorkerConnectionInfo(t *testing.T, tw *worker.TestWorker) map[string]worker.TestConnectionInfo {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
si, ok := tw.LookupSession(s.sessionId)
|
|
// This is always an error if the session has been removed from the
|
|
// local state.
|
|
require.True(ok)
|
|
|
|
// Likewise, we require the helper to always be used with
|
|
// connections.
|
|
require.Greater(len(si.Connections), 0, "should have at least one connection")
|
|
|
|
return si.Connections
|
|
}
|
|
|
|
func (s *TestSession) testWorkerConnectionIds(t *testing.T, tw *worker.TestWorker) []string {
|
|
t.Helper()
|
|
conns := s.testWorkerConnectionInfo(t, tw)
|
|
result := make([]string, 0, len(conns))
|
|
for _, conn := range conns {
|
|
result = append(result, conn.Id)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// TestSessionConnection abstracts a connected session.
|
|
type TestSessionConnection struct {
|
|
conn net.Conn
|
|
logger hclog.Logger
|
|
}
|
|
|
|
// Connect returns a TestSessionConnection for a TestSession. Check
|
|
// the unexported connect method for the lower-level details.
|
|
func (s *TestSession) Connect(
|
|
ctx context.Context,
|
|
t *testing.T, // Just to add cleanup
|
|
) *TestSessionConnection {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
conn := s.connect(ctx, t)
|
|
require.NotNil(conn)
|
|
t.Cleanup(func() {
|
|
conn.Close()
|
|
})
|
|
require.NotNil(conn)
|
|
|
|
return &TestSessionConnection{
|
|
conn: conn,
|
|
logger: s.logger,
|
|
}
|
|
}
|
|
|
|
// testSendRecv runs a basic send/receive test over the returned
|
|
// connection, and returns whether or not all "pings" made it
|
|
// through.
|
|
//
|
|
// The test is a simple sequence number, ticking up every second to
|
|
// max. The passed in conn is expected to copy whatever it is
|
|
// received.
|
|
func (c *TestSessionConnection) testSendRecv(t *testing.T) bool {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
|
|
// This is a fairly arbitrary value, as the send/recv is
|
|
// instantaneous. The main key here is just to make sure that we do
|
|
// it a reasonable amount of times to know the connection is
|
|
// stable.
|
|
const testSendRecvSendMax = 100
|
|
for i := uint32(0); i < testSendRecvSendMax; i++ {
|
|
// Shuttle over the sequence number as base64.
|
|
err := binary.Write(c.conn, binary.LittleEndian, i)
|
|
if err != nil {
|
|
c.logger.Debug("received error during write", "err", err)
|
|
if errors.Is(err, net.ErrClosed) ||
|
|
errors.Is(err, io.EOF) ||
|
|
errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) {
|
|
return false
|
|
}
|
|
|
|
require.FailNow(err.Error())
|
|
}
|
|
|
|
// Read it back
|
|
var j uint32
|
|
err = binary.Read(c.conn, binary.LittleEndian, &j)
|
|
if err != nil {
|
|
c.logger.Debug("received error during read", "err", err, "num_successfully_sent", i)
|
|
if errors.Is(err, net.ErrClosed) ||
|
|
errors.Is(err, io.EOF) ||
|
|
errors.Is(err, websocket.CloseError{Code: websocket.StatusPolicyViolation, Reason: "timed out"}) {
|
|
return false
|
|
}
|
|
|
|
require.FailNow(err.Error())
|
|
}
|
|
|
|
require.Equal(j, i)
|
|
|
|
// Sleep 1s
|
|
// time.Sleep(time.Second)
|
|
}
|
|
|
|
c.logger.Debug("finished send/recv successfully", "num_successfully_sent", testSendRecvSendMax)
|
|
return true
|
|
}
|
|
|
|
// TestSendRecvAll asserts that we were able to send/recv all pings
|
|
// over the test connection.
|
|
func (c *TestSessionConnection) TestSendRecvAll(t *testing.T) {
|
|
t.Helper()
|
|
require.True(t, c.testSendRecv(t))
|
|
c.logger.Debug("successfully asserted send/recv as passing")
|
|
}
|
|
|
|
// TestSendRecvFail asserts that we were able to send/recv all pings
|
|
// over the test connection.
|
|
func (c *TestSessionConnection) TestSendRecvFail(t *testing.T) {
|
|
t.Helper()
|
|
require.False(t, c.testSendRecv(t))
|
|
c.logger.Debug("successfully asserted send/recv as failing")
|
|
}
|
|
|
|
// TestTcpServer defines a simple testing TCP server. Data is simply
|
|
// copied from its input to its output.
|
|
type TestTcpServer struct {
|
|
logger hclog.Logger
|
|
ln net.Listener
|
|
conns map[string]net.Conn
|
|
}
|
|
|
|
// Port returns the port number for the test server listener.
|
|
func (ts *TestTcpServer) Port() uint32 {
|
|
_, portS, err := net.SplitHostPort(ts.ln.Addr().String())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if portS == "" {
|
|
panic("empty port in what should be TCP listener")
|
|
}
|
|
|
|
port, err := strconv.Atoi(portS)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if port < 1 {
|
|
panic("zero or negative port in what should be a TCP listener")
|
|
}
|
|
|
|
return uint32(port)
|
|
}
|
|
|
|
// Close closes the listener and all active connections.
|
|
func (ts *TestTcpServer) Close() {
|
|
ts.ln.Close()
|
|
for _, conn := range ts.conns {
|
|
conn.Close()
|
|
}
|
|
}
|
|
|
|
func (ts *TestTcpServer) run() {
|
|
for {
|
|
conn, err := ts.ln.Accept()
|
|
if err != nil {
|
|
if !errors.Is(err, net.ErrClosed) {
|
|
ts.logger.Error("Accept() error in TestTcpServer", "err", err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
ts.conns[conn.RemoteAddr().String()] = conn
|
|
|
|
go func(c net.Conn) {
|
|
io.Copy(c, c)
|
|
c.Close()
|
|
}(conn)
|
|
}
|
|
}
|
|
|
|
// NewTestTcpServer creates a TestTcpServer and runs it.
|
|
func NewTestTcpServer(t *testing.T, logger hclog.Logger) *TestTcpServer {
|
|
t.Helper()
|
|
require := require.New(t)
|
|
ts := &TestTcpServer{
|
|
logger: logger,
|
|
conns: make(map[string]net.Conn),
|
|
}
|
|
var err error
|
|
ts.ln, err = net.Listen("tcp", ":0")
|
|
require.NoError(err)
|
|
|
|
go ts.run()
|
|
return ts
|
|
}
|