You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/worker/session/manager_test.go

480 lines
15 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package session
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"errors"
"fmt"
"math/big"
"net"
"sync"
"testing"
"time"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/targets"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
)
// TestManager_Get verifies that Get returns a non-nil session once that
// session has been loaded through LoadLocalSession, and returns nil for a
// session id that was never loaded.
func TestManager_Get(t *testing.T) {
	ctx := context.Background()

	client := pbs.NewMockSessionServiceClient()
	// Every lookup answers with a pending session "foo" that expires an hour
	// from the time of the call.
	client.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
		authz := &targets.SessionAuthorizationData{
			SessionId:   "foo",
			Certificate: createTestCert(t),
		}
		resp := &pbs.LookupSessionResponse{
			Authorization: authz,
			Version:       1,
			Expiration:    timestamppb.New(time.Now().Add(time.Hour)),
			Status:        pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
		}
		return resp, nil
	}

	mgr, err := NewManager(client)
	require.NoError(t, err)

	_, err = mgr.LoadLocalSession(ctx, "foo", "worker id")
	require.NoError(t, err)

	assert.NotNil(t, mgr.Get("foo"))
	assert.Nil(t, mgr.Get("unknown"))
}
// TestManager_DeleteLocalSession verifies that a loaded session is no longer
// returned by Get after DeleteLocalSession, and that deleting the same id a
// second time leaves the manager in the same state.
func TestManager_DeleteLocalSession(t *testing.T) {
	ctx := context.Background()

	client := pbs.NewMockSessionServiceClient()
	// The mock echoes the requested session id back in the authorization data.
	client.LookupSessionFn = func(_ context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
		resp := &pbs.LookupSessionResponse{
			Authorization: &targets.SessionAuthorizationData{
				SessionId:   req.GetSessionId(),
				Certificate: createTestCert(t),
			},
			Version:    1,
			Expiration: timestamppb.New(time.Now().Add(time.Hour)),
			Status:     pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
		}
		return resp, nil
	}

	mgr, err := NewManager(client)
	require.NoError(t, err)

	_, err = mgr.LoadLocalSession(ctx, "foo", "worker id")
	require.NoError(t, err)
	require.NotNil(t, mgr.Get("foo"))

	// First delete removes the session.
	mgr.DeleteLocalSession([]string{"foo"})
	assert.Nil(t, mgr.Get("foo"))

	// Deleting an id that is already gone is ok.
	mgr.DeleteLocalSession([]string{"foo"})
	assert.Nil(t, mgr.Get("foo"))
}
// TestManager_RequestCloseConnections exercises RequestCloseConnections
// against a mock session service client. It asserts that the call reports
// false for nil or empty input, when the controller returns an error, and when
// the connection id is not known locally, and that it reports true once the
// connections have been authorized through their sessions.
func TestManager_RequestCloseConnections(t *testing.T) {
	ctx := context.Background()
	mockSessionClient := pbs.NewMockSessionServiceClient()
	manager, err := NewManager(mockSessionClient)
	require.NoError(t, err)

	// Nothing to close.
	assert.False(t, manager.RequestCloseConnections(ctx, nil))
	assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{}))

	// The mock echoes the requested session id back so that several distinct
	// sessions can be loaded through the same function.
	mockSessionClient.LookupSessionFn = func(_ context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
		return &pbs.LookupSessionResponse{
			Authorization: &targets.SessionAuthorizationData{
				SessionId:   req.GetSessionId(),
				Certificate: createTestCert(t),
			},
			Version:    1,
			Expiration: timestamppb.New(time.Now().Add(time.Hour)),
			Status:     pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
		}, nil
	}
	session1, err := manager.LoadLocalSession(ctx, "sess1", "worker id")
	require.NoError(t, err)
	session2, err := manager.LoadLocalSession(ctx, "sess2", "worker id")
	require.NoError(t, err)
	session3, err := manager.LoadLocalSession(ctx, "sess3", "worker id")
	require.NoError(t, err)

	// The controller fails the close request.
	mockSessionClient.CloseConnectionFn = func(context.Context, *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
		return nil, errors.New("error")
	}
	assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{"connection id": {SessionId: session1.GetId()}}))

	// From here on the controller reports every requested connection as closed.
	mockSessionClient.CloseConnectionFn = func(_ context.Context, req *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
		var data []*pbs.CloseConnectionResponseData
		for _, r := range req.GetCloseRequestData() {
			data = append(data, &pbs.CloseConnectionResponseData{
				ConnectionId: r.GetConnectionId(),
				Status:       pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED,
			})
		}
		return &pbs.CloseConnectionResponse{CloseResponseData: data}, nil
	}

	// There is no connection information yet for this connection.
	assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
		"random_connection_id": {SessionId: session1.GetId()},
	}))

	// Load the connection info into the local storage
	mockSessionClient.AuthorizeConnectionFn = func(_ context.Context, req *pbs.AuthorizeConnectionRequest) (*pbs.AuthorizeConnectionResponse, error) {
		return &pbs.AuthorizeConnectionResponse{
			ConnectionId:    fmt.Sprintf("connection_%s", req.GetSessionId()),
			Status:          pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_AUTHORIZED,
			ConnectionsLeft: -1,
		}, nil
	}

	// Only the cancel function is needed; the derived context is discarded.
	// Defer the cancel so the derived context is always released when the
	// test returns, whether or not the code under test invokes it. Calling a
	// CancelFunc more than once is safe.
	_, cancelFn := context.WithCancel(ctx)
	defer cancelFn()

	c1, _, err := session1.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
	require.NoError(t, err)
	assert.True(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
		c1.GetConnectionId(): {SessionId: session1.GetId()},
	}))

	// Closing connections that belong to different sessions in one request.
	c2, _, err := session2.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
	require.NoError(t, err)
	c3, _, err := session3.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
	require.NoError(t, err)
	assert.True(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
		c2.GetConnectionId(): {SessionId: session2.GetId()},
		c3.GetConnectionId(): {SessionId: session3.GetId()},
	}))
}
// TestManager_LoadLocalSession checks LoadLocalSession's error handling and
// its success path against a mock session service client.
//
// Each entry in errorCases supplies the arguments passed to LoadLocalSession,
// an optional canned controller response, and a substring that the returned
// error must contain. The "success" subtest asserts the fields exposed by the
// returned session.
func TestManager_LoadLocalSession(t *testing.T) {
	mockSessionClient := pbs.NewMockSessionServiceClient()
	errorCases := []struct {
		name       string
		inId       string
		inWorkerId string
		// controllerResponse is what the mocked LookupSession returns. It is
		// left nil for cases that are expected to fail before the controller
		// is contacted.
		controllerResponse func() (*pbs.LookupSessionResponse, error)
		wantError          string
	}{
		{
			name:       "no id",
			inWorkerId: "worker id",
			wantError:  "id is not set",
		},
		{
			name:      "no worker id",
			inId:      "id",
			wantError: "workerId is not set",
		},
		{
			name:       "controller error",
			inId:       "id",
			inWorkerId: "worker id",
			controllerResponse: func() (*pbs.LookupSessionResponse, error) {
				return nil, errors.New("some error")
			},
			wantError: "some error",
		},
		{
			name:       "cant parse cert",
			inId:       "id",
			inWorkerId: "worker id",
			controllerResponse: func() (*pbs.LookupSessionResponse, error) {
				return &pbs.LookupSessionResponse{
					Authorization: &targets.SessionAuthorizationData{
						SessionId:   "foo",
						Certificate: []byte("unparseable"),
					},
					Version:    1,
					Expiration: timestamppb.New(time.Now().Add(time.Hour)),
					Status:     pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
				}, nil
			},
			wantError: "error parsing session certificate",
		},
		{
			name:       "expired session",
			inId:       "id",
			inWorkerId: "worker id",
			controllerResponse: func() (*pbs.LookupSessionResponse, error) {
				return &pbs.LookupSessionResponse{
					Authorization: &targets.SessionAuthorizationData{
						SessionId:   "foo",
						Certificate: createTestCert(t),
					},
					Version: 1,
					// Expiration lies an hour in the past.
					Expiration: timestamppb.New(time.Now().Add(-time.Hour)),
					Status:     pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
				}, nil
			},
			wantError: "session is expired",
		},
	}
	for _, tc := range errorCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			mockSessionClient.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
				// Cases without a canned response are not supposed to reach
				// the controller. Return a descriptive error rather than
				// panicking on a nil function value, so that such a
				// regression surfaces as a mismatch against wantError.
				if tc.controllerResponse == nil {
					return nil, errors.New("unexpected LookupSession call: test case defines no controller response")
				}
				return tc.controllerResponse()
			}
			manager, err := NewManager(mockSessionClient)
			require.NoError(t, err)
			s, err := manager.LoadLocalSession(context.Background(), tc.inId, tc.inWorkerId)
			require.Error(t, err)
			assert.Nil(t, s)
			assert.ErrorContains(t, err, tc.wantError)
		})
	}

	t.Run("success", func(t *testing.T) {
		expirationTime := time.Now().Add(time.Hour).UTC()
		mockSessionClient.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
			return &pbs.LookupSessionResponse{
				Authorization: &targets.SessionAuthorizationData{
					SessionId:   "foo",
					Certificate: createTestCert(t),
				},
				ConnectionLimit: -1,
				Version:         1,
				Expiration:      timestamppb.New(expirationTime),
				Status:          pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
			}, nil
		}
		manager, err := NewManager(mockSessionClient)
		require.NoError(t, err)
		s, err := manager.LoadLocalSession(context.Background(), "foo", "worker id")
		require.NoError(t, err)
		assert.Equal(t, "foo", s.GetId())
		assert.Equal(t, pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING, s.GetStatus())
		assert.Empty(t, s.GetLocalConnections())
		assert.Equal(t, int32(-1), s.GetConnectionLimit())
		assert.Equal(t, expirationTime, s.GetExpiration())
	})
}
// createTestCert returns the DER encoding of a freshly generated, self-signed
// ed25519 CA certificate for use as session certificate bytes in tests. The
// certificate is valid from 30 seconds before the call until 5 minutes after
// it. Any failure while generating the key or the certificate fails the test
// immediately.
func createTestCert(t *testing.T) []byte {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	template := &x509.Certificate{
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
		SerialNumber:          big.NewInt(0),
		NotBefore:             time.Now().Add(-30 * time.Second),
		NotAfter:              time.Now().Add(5 * time.Minute),
		BasicConstraintsValid: true,
		IsCA:                  true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		DNSNames:              []string{"/tmp/boundary-opslistener-test0.sock", "/tmp/boundary-opslistener-test1.sock"},
	}
	// The template is passed as both certificate and parent, which makes the
	// result self-signed.
	certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv)
	require.NoError(t, err)
	return certBytes
}
// TestWorkerSetCloseTimeForResponse is a table-driven test for
// setCloseTimeForResponse.
//
// Each case provides the per-session close response data handed to the
// function (sessionCloseInfo), a constructor for the manager's local session
// map (sessionInfoMap), the connection ids expected in the first return value
// (expected), the set of connection ids whose CloseTime must be non-zero
// afterwards (expectedClosed), and the errors expected in the second return
// value (expectedErr). Connections not listed in expectedClosed must keep a
// zero CloseTime.
func TestWorkerSetCloseTimeForResponse(t *testing.T) {
	cases := []struct {
		name             string
		sessionCloseInfo map[string][]*pbs.CloseConnectionResponseData
		sessionInfoMap   func() *sync.Map
		expected         []string
		expectedClosed   map[string]struct{}
		expectedErr      []error
	}{
		{
			name: "basic",
			sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
				"one": {
					{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
				"two": {
					{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
			},
			sessionInfoMap: func() *sync.Map {
				m := new(sync.Map)
				m.Store("one", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "one",
					}},
					connInfoMap: map[string]*ConnInfo{
						"foo": {Id: "foo"},
					},
				})
				m.Store("two", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "two",
					}},
					connInfoMap: map[string]*ConnInfo{
						"bar": {Id: "bar"},
					},
				})
				// Session "three" is not part of the close response.
				m.Store("three", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "three",
					}},
					connInfoMap: map[string]*ConnInfo{
						"baz": {Id: "baz"},
					},
				})
				return m
			},
			expected: []string{"foo", "bar"},
			expectedClosed: map[string]struct{}{
				"foo": {},
				"bar": {},
			},
		},
		{
			name: "not closed",
			sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
				"one": {
					{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
				"two": {
					{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED},
				},
			},
			sessionInfoMap: func() *sync.Map {
				m := new(sync.Map)
				m.Store("one", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "one",
					}},
					connInfoMap: map[string]*ConnInfo{
						"foo": {Id: "foo"},
					},
				})
				m.Store("two", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "two",
					}},
					connInfoMap: map[string]*ConnInfo{
						"bar": {Id: "bar"},
					},
				})
				return m
			},
			expected: []string{"foo"},
			expectedClosed: map[string]struct{}{
				"foo": {},
			},
		},
		{
			name: "missing session",
			sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
				"one": {
					{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
				"two": {
					{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
			},
			sessionInfoMap: func() *sync.Map {
				m := new(sync.Map)
				m.Store("one", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "one",
					}},
					connInfoMap: map[string]*ConnInfo{
						"foo": {Id: "foo"},
					},
				})
				return m
			},
			expected: []string{"foo"},
			expectedClosed: map[string]struct{}{
				"foo": {},
			},
			expectedErr: []error{
				errors.New(`could not find session ID "two" in local state after closing connections`),
			},
		},
		{
			name: "missing connection",
			sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
				"one": {
					{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
				"two": {
					{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
				},
			},
			sessionInfoMap: func() *sync.Map {
				m := new(sync.Map)
				m.Store("one", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "one",
					}},
					connInfoMap: map[string]*ConnInfo{
						"foo": {Id: "foo"},
					},
					sessionId: "one",
				})
				// Session "two" exists locally but has no connection info.
				m.Store("two", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "two",
					}},
					sessionId: "two",
				})
				return m
			},
			expected: []string{"foo"},
			expectedClosed: map[string]struct{}{
				"foo": {},
			},
			expectedErr: []error{
				errors.New(`could not find connection ID "bar" for session ID "two" in local state`),
			},
		},
		{
			name:             "empty",
			sessionCloseInfo: make(map[string][]*pbs.CloseConnectionResponseData),
			sessionInfoMap: func() *sync.Map {
				m := new(sync.Map)
				m.Store("one", &sess{
					resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
						SessionId: "one",
					}},
					connInfoMap: map[string]*ConnInfo{
						"foo": {Id: "foo"},
					},
				})
				return m
			},
			expected:       []string{},
			expectedClosed: map[string]struct{}{},
		},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			require := require.New(t)
			// Build the manager directly around the case's session map; no
			// placeholder map is needed.
			mgr := &manager{sessionMap: tc.sessionInfoMap()}
			actual, actualErr := setCloseTimeForResponse(mgr, tc.sessionCloseInfo)

			// Assert all close times were set
			mgr.ForEachLocalSession(func(value Session) bool {
				t.Helper()
				for _, ci := range value.GetLocalConnections() {
					if _, ok := tc.expectedClosed[ci.Id]; ok {
						require.NotEqual(time.Time{}, ci.CloseTime)
					} else {
						require.Equal(time.Time{}, ci.CloseTime)
					}
				}
				return true
			})

			// Assert return values
			require.ElementsMatch(tc.expected, actual)
			require.ElementsMatch(tc.expectedErr, actualErr)
		})
	}
}