mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
480 lines
15 KiB
480 lines
15 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package session
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
|
"github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/targets"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
func TestManager_Get(t *testing.T) {
|
|
mockSessionClient := pbs.NewMockSessionServiceClient()
|
|
mockSessionClient.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "foo",
|
|
Certificate: createTestCert(t),
|
|
},
|
|
Version: 1,
|
|
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
}
|
|
manager, err := NewManager(mockSessionClient)
|
|
require.NoError(t, err)
|
|
_, err = manager.LoadLocalSession(context.Background(), "foo", "worker id")
|
|
require.NoError(t, err)
|
|
|
|
assert.NotNil(t, manager.Get("foo"))
|
|
assert.Nil(t, manager.Get("unknown"))
|
|
}
|
|
|
|
func TestManager_DeleteLocalSession(t *testing.T) {
|
|
mockSessionClient := pbs.NewMockSessionServiceClient()
|
|
mockSessionClient.LookupSessionFn = func(_ context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: req.GetSessionId(),
|
|
Certificate: createTestCert(t),
|
|
},
|
|
Version: 1,
|
|
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
}
|
|
manager, err := NewManager(mockSessionClient)
|
|
require.NoError(t, err)
|
|
_, err = manager.LoadLocalSession(context.Background(), "foo", "worker id")
|
|
require.NoError(t, err)
|
|
|
|
require.NotNil(t, manager.Get("foo"))
|
|
manager.DeleteLocalSession([]string{"foo"})
|
|
assert.Nil(t, manager.Get("foo"))
|
|
// A second call to DeleteLocalSession is ok.
|
|
manager.DeleteLocalSession([]string{"foo"})
|
|
assert.Nil(t, manager.Get("foo"))
|
|
}
|
|
|
|
func TestManager_RequestCloseConnections(t *testing.T) {
|
|
ctx := context.Background()
|
|
mockSessionClient := pbs.NewMockSessionServiceClient()
|
|
|
|
manager, err := NewManager(mockSessionClient)
|
|
require.NoError(t, err)
|
|
assert.False(t, manager.RequestCloseConnections(ctx, nil))
|
|
assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{}))
|
|
|
|
mockSessionClient.LookupSessionFn = func(_ context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: req.GetSessionId(),
|
|
Certificate: createTestCert(t),
|
|
},
|
|
Version: 1,
|
|
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
}
|
|
session1, err := manager.LoadLocalSession(ctx, "sess1", "worker id")
|
|
require.NoError(t, err)
|
|
session2, err := manager.LoadLocalSession(ctx, "sess2", "worker id")
|
|
require.NoError(t, err)
|
|
session3, err := manager.LoadLocalSession(ctx, "sess3", "worker id")
|
|
require.NoError(t, err)
|
|
|
|
mockSessionClient.CloseConnectionFn = func(context.Context, *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
|
|
return nil, errors.New("error")
|
|
}
|
|
assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{"connection id": {SessionId: session1.GetId()}}))
|
|
mockSessionClient.CloseConnectionFn = func(_ context.Context, req *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
|
|
var data []*pbs.CloseConnectionResponseData
|
|
for _, r := range req.GetCloseRequestData() {
|
|
data = append(data, &pbs.CloseConnectionResponseData{
|
|
ConnectionId: r.GetConnectionId(),
|
|
Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED,
|
|
})
|
|
}
|
|
return &pbs.CloseConnectionResponse{CloseResponseData: data}, nil
|
|
}
|
|
// There is no connection information yet for this connection.
|
|
assert.False(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
|
|
"random_connection_id": {SessionId: session1.GetId()},
|
|
}))
|
|
|
|
// Load the connection info into the local storage
|
|
mockSessionClient.AuthorizeConnectionFn = func(_ context.Context, req *pbs.AuthorizeConnectionRequest) (*pbs.AuthorizeConnectionResponse, error) {
|
|
return &pbs.AuthorizeConnectionResponse{
|
|
ConnectionId: fmt.Sprintf("connection_%s", req.GetSessionId()),
|
|
Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_AUTHORIZED,
|
|
ConnectionsLeft: -1,
|
|
}, nil
|
|
}
|
|
_, cancelFn := context.WithCancel(ctx)
|
|
c1, _, err := session1.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
|
|
c1.GetConnectionId(): {SessionId: session1.GetId()},
|
|
}))
|
|
|
|
c2, _, err := session2.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
|
|
require.NoError(t, err)
|
|
c3, _, err := session3.RequestAuthorizeConnection(ctx, "worker id", cancelFn)
|
|
require.NoError(t, err)
|
|
assert.True(t, manager.RequestCloseConnections(ctx, map[string]*ConnectionCloseData{
|
|
c2.GetConnectionId(): {SessionId: session2.GetId()},
|
|
c3.GetConnectionId(): {SessionId: session3.GetId()},
|
|
}))
|
|
}
|
|
|
|
func TestManager_LoadLocalSession(t *testing.T) {
|
|
mockSessionClient := pbs.NewMockSessionServiceClient()
|
|
errorCases := []struct {
|
|
name string
|
|
inId string
|
|
inWorkerId string
|
|
controllerResponse func() (*pbs.LookupSessionResponse, error)
|
|
wantError string
|
|
}{
|
|
{
|
|
name: "no id",
|
|
inWorkerId: "worker id",
|
|
wantError: "id is not set",
|
|
},
|
|
{
|
|
name: "no worker id",
|
|
inId: "id",
|
|
wantError: "workerId is not set",
|
|
},
|
|
{
|
|
name: "controller error",
|
|
inId: "id",
|
|
inWorkerId: "worker id",
|
|
controllerResponse: func() (*pbs.LookupSessionResponse, error) {
|
|
return nil, errors.New("some error")
|
|
},
|
|
wantError: "some error",
|
|
},
|
|
{
|
|
name: "cant parse cert",
|
|
inId: "id",
|
|
inWorkerId: "worker id",
|
|
controllerResponse: func() (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "foo",
|
|
Certificate: []byte("unparseable"),
|
|
},
|
|
Version: 1,
|
|
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
},
|
|
wantError: "error parsing session certificate",
|
|
},
|
|
{
|
|
name: "expired session",
|
|
inId: "id",
|
|
inWorkerId: "worker id",
|
|
controllerResponse: func() (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "foo",
|
|
Certificate: createTestCert(t),
|
|
},
|
|
Version: 1,
|
|
Expiration: timestamppb.New(time.Now().Add(-time.Hour)),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
},
|
|
wantError: "session is expired",
|
|
},
|
|
}
|
|
for _, tc := range errorCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mockSessionClient.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
return tc.controllerResponse()
|
|
}
|
|
manager, err := NewManager(mockSessionClient)
|
|
require.NoError(t, err)
|
|
s, err := manager.LoadLocalSession(context.Background(), tc.inId, tc.inWorkerId)
|
|
require.Error(t, err)
|
|
assert.Nil(t, s)
|
|
|
|
assert.ErrorContains(t, err, tc.wantError)
|
|
})
|
|
}
|
|
|
|
t.Run("success", func(t *testing.T) {
|
|
expirationTime := time.Now().Add(time.Hour).UTC()
|
|
mockSessionClient.LookupSessionFn = func(context.Context, *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
return &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "foo",
|
|
Certificate: createTestCert(t),
|
|
},
|
|
ConnectionLimit: -1,
|
|
Version: 1,
|
|
Expiration: timestamppb.New(expirationTime),
|
|
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
|
}, nil
|
|
}
|
|
manager, err := NewManager(mockSessionClient)
|
|
require.NoError(t, err)
|
|
s, err := manager.LoadLocalSession(context.Background(), "foo", "worker id")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "foo", s.GetId())
|
|
assert.Equal(t, pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING, s.GetStatus())
|
|
assert.Empty(t, s.GetLocalConnections())
|
|
assert.Equal(t, int32(-1), s.GetConnectionLimit())
|
|
assert.Equal(t, expirationTime, s.GetExpiration())
|
|
})
|
|
}
|
|
|
|
func createTestCert(t *testing.T) []byte {
|
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
require.NoError(t, err)
|
|
|
|
template := &x509.Certificate{
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
|
|
SerialNumber: big.NewInt(0),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(5 * time.Minute),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
DNSNames: []string{"/tmp/boundary-opslistener-test0.sock", "/tmp/boundary-opslistener-test1.sock"},
|
|
}
|
|
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv)
|
|
require.NoError(t, err)
|
|
|
|
return certBytes
|
|
}
|
|
|
|
func TestWorkerSetCloseTimeForResponse(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
sessionCloseInfo map[string][]*pbs.CloseConnectionResponseData
|
|
sessionInfoMap func() *sync.Map
|
|
expected []string
|
|
expectedClosed map[string]struct{}
|
|
expectedErr []error
|
|
}{
|
|
{
|
|
name: "basic",
|
|
sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
|
|
"one": {
|
|
{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
"two": {
|
|
{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
},
|
|
sessionInfoMap: func() *sync.Map {
|
|
m := new(sync.Map)
|
|
m.Store("one", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "one",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"foo": {Id: "foo"},
|
|
},
|
|
})
|
|
m.Store("two", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "two",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"bar": {Id: "bar"},
|
|
},
|
|
})
|
|
m.Store("three", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "three",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"baz": {Id: "baz"},
|
|
},
|
|
})
|
|
|
|
return m
|
|
},
|
|
expected: []string{"foo", "bar"},
|
|
expectedClosed: map[string]struct{}{
|
|
"foo": {},
|
|
"bar": {},
|
|
},
|
|
},
|
|
{
|
|
name: "not closed",
|
|
sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
|
|
"one": {
|
|
{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
"two": {
|
|
{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED},
|
|
},
|
|
},
|
|
sessionInfoMap: func() *sync.Map {
|
|
m := new(sync.Map)
|
|
m.Store("one", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "one",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"foo": {Id: "foo"},
|
|
},
|
|
})
|
|
m.Store("two", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "two",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"bar": {Id: "bar"},
|
|
},
|
|
})
|
|
|
|
return m
|
|
},
|
|
expected: []string{"foo"},
|
|
expectedClosed: map[string]struct{}{
|
|
"foo": {},
|
|
},
|
|
},
|
|
{
|
|
name: "missing session",
|
|
sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
|
|
"one": {
|
|
{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
"two": {
|
|
{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
},
|
|
sessionInfoMap: func() *sync.Map {
|
|
m := new(sync.Map)
|
|
m.Store("one", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "one",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"foo": {Id: "foo"},
|
|
},
|
|
})
|
|
|
|
return m
|
|
},
|
|
expected: []string{"foo"},
|
|
expectedClosed: map[string]struct{}{
|
|
"foo": {},
|
|
},
|
|
expectedErr: []error{
|
|
errors.New(`could not find session ID "two" in local state after closing connections`),
|
|
},
|
|
},
|
|
{
|
|
name: "missing connection",
|
|
sessionCloseInfo: map[string][]*pbs.CloseConnectionResponseData{
|
|
"one": {
|
|
{ConnectionId: "foo", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
"two": {
|
|
{ConnectionId: "bar", Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED},
|
|
},
|
|
},
|
|
sessionInfoMap: func() *sync.Map {
|
|
m := new(sync.Map)
|
|
m.Store("one", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "one",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"foo": {Id: "foo"},
|
|
},
|
|
sessionId: "one",
|
|
})
|
|
m.Store("two", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "two",
|
|
}},
|
|
sessionId: "two",
|
|
})
|
|
|
|
return m
|
|
},
|
|
expected: []string{"foo"},
|
|
expectedClosed: map[string]struct{}{
|
|
"foo": {},
|
|
},
|
|
expectedErr: []error{
|
|
errors.New(`could not find connection ID "bar" for session ID "two" in local state`),
|
|
},
|
|
},
|
|
{
|
|
name: "empty",
|
|
sessionCloseInfo: make(map[string][]*pbs.CloseConnectionResponseData),
|
|
sessionInfoMap: func() *sync.Map {
|
|
m := new(sync.Map)
|
|
m.Store("one", &sess{
|
|
resp: &pbs.LookupSessionResponse{Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: "one",
|
|
}},
|
|
connInfoMap: map[string]*ConnInfo{
|
|
"foo": {Id: "foo"},
|
|
},
|
|
})
|
|
|
|
return m
|
|
},
|
|
expected: []string{},
|
|
expectedClosed: map[string]struct{}{},
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
tc := tc
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
require := require.New(t)
|
|
manager := &manager{sessionMap: new(sync.Map)}
|
|
manager.sessionMap = tc.sessionInfoMap()
|
|
actual, actualErr := setCloseTimeForResponse(manager, tc.sessionCloseInfo)
|
|
|
|
// Assert all close times were set
|
|
manager.ForEachLocalSession(func(value Session) bool {
|
|
t.Helper()
|
|
for _, ci := range value.GetLocalConnections() {
|
|
if _, ok := tc.expectedClosed[ci.Id]; ok {
|
|
require.NotEqual(time.Time{}, ci.CloseTime)
|
|
} else {
|
|
require.Equal(time.Time{}, ci.CloseTime)
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
// Assert return values
|
|
require.ElementsMatch(tc.expected, actual)
|
|
require.ElementsMatch(tc.expectedErr, actualErr)
|
|
})
|
|
}
|
|
}
|