fix(target): Assign session connection limit to SessionAuthorization.ConnectionLimit (#5396)

* fix(target): Assign session connection limit to SessionAuthorization.ConnectionLimit

* test(target): Update TestAuthorizeSession & TestAuthorizeSessionTypedCredentials to check session connection limit
pull/5397/head
David Kanney 1 year ago committed by GitHub
parent 157c305753
commit 2342e4b644
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1214,6 +1214,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession
HostSetId: hostSetId,
Endpoint: endpointUrl.String(),
Credentials: creds,
ConnectionLimit: t.GetSessionConnectionLimit(),
}
ret.SessionRecordingId, err = SessionRecordingFn(

@ -3547,32 +3547,36 @@ func TestAuthorizeSession(t *testing.T) {
const defaultPort = 2
cases := []struct {
name string
hostSourceId string
credSourceId string
wantedHostId string
wantedEndpoint string
name string
hostSourceId string
credSourceId string
wantedHostId string
wantedEndpoint string
wantedConnectionLimit int32
}{
{
name: "static host",
hostSourceId: shs.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: h.GetPublicId(),
wantedEndpoint: fmt.Sprintf("%s:%d", h.GetAddress(), defaultPort),
name: "static host",
hostSourceId: shs.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: h.GetPublicId(),
wantedEndpoint: fmt.Sprintf("%s:%d", h.GetAddress(), defaultPort),
wantedConnectionLimit: -1,
},
{
name: "static host with port defined",
hostSourceId: shsWithPort.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: hWithPort.GetPublicId(),
wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort),
name: "static host with port defined",
hostSourceId: shsWithPort.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: hWithPort.GetPublicId(),
wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort),
wantedConnectionLimit: 10,
},
{
name: "plugin host",
hostSourceId: phs.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: "?",
wantedEndpoint: fmt.Sprintf("10.0.0.1:%d", defaultPort),
name: "plugin host",
hostSourceId: phs.GetPublicId(),
credSourceId: clsResp.GetItem().GetId(),
wantedHostId: "?",
wantedEndpoint: fmt.Sprintf("10.0.0.1:%d", defaultPort),
wantedConnectionLimit: 100,
},
}
@ -3588,7 +3592,7 @@ func TestAuthorizeSession(t *testing.T) {
ctx, err = event.NewCorrelationIdContext(ctx, corId)
require.NoError(t, err)
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort))
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort), target.WithSessionConnectionLimit(tc.wantedConnectionLimit))
apiTar, err := s.AddTargetHostSources(ctx, &pbs.AddTargetHostSourcesRequest{
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
@ -3660,7 +3664,8 @@ func TestAuthorizeSession(t *testing.T) {
},
},
},
EndpointPort: uint32(defaultPort),
EndpointPort: uint32(defaultPort),
ConnectionLimit: tc.wantedConnectionLimit,
// TODO: validate the contents of the authorization token is what is expected
}
wantSecret := map[string]any{
@ -3991,12 +3996,13 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
require.NotNil(t, clsRespSshPrivateKeyWithPassWithMapping)
cases := []struct {
name string
hostSourceId string
credSourceId string
wantedHostId string
wantedEndpoint string
wantedCred *pb.SessionCredential
name string
hostSourceId string
credSourceId string
wantedHostId string
wantedEndpoint string
wantedCred *pb.SessionCredential
wantedConnectionLimit int32
}{
{
name: "vault-unspecified",
@ -4013,6 +4019,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
Type: vault.GenericLibrarySubtype.String(),
},
},
wantedConnectionLimit: 1,
},
{
name: "vault-usernamepassword",
@ -4039,6 +4046,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 10,
},
{
name: "vault-UsernamePassword-with-mapping",
@ -4065,6 +4073,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 100,
},
{
name: "static-UsernamePassword",
@ -4091,6 +4100,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 1000,
},
{
name: "static-ssh-private-key",
@ -4117,6 +4127,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 10000,
},
{
name: "vault-ssh-private-key",
@ -4143,6 +4154,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 50000,
},
{
name: "vault-ssh-private-key-with-mapping",
@ -4169,6 +4181,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 100000,
},
{
name: "static-ssh-private-key-with-pass",
@ -4196,6 +4209,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 500000,
},
{
name: "vault-ssh-private-key-with-pass",
@ -4223,6 +4237,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: 1000000,
},
{
name: "vault-ssh-private-key-with-pass-with-mapping",
@ -4250,13 +4265,14 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
return st
}(),
},
wantedConnectionLimit: -1,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
const defaultPort = 2
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort))
tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort), target.WithSessionConnectionLimit(tc.wantedConnectionLimit))
apiTar, err := s.AddTargetHostSources(ctx, &pbs.AddTargetHostSourcesRequest{
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
@ -4293,15 +4309,16 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
Description: proj.GetDescription(),
ParentScopeId: proj.GetParentId(),
},
TargetId: tar.GetPublicId(),
UserId: at.GetIamUserId(),
HostSetId: tc.hostSourceId,
HostId: tc.wantedHostId,
Type: "tcp",
Endpoint: fmt.Sprintf("tcp://%s:%d", tc.wantedEndpoint, defaultPort),
Credentials: []*pb.SessionCredential{tc.wantedCred},
EndpointPort: uint32(defaultPort),
Expiration: asRes.Item.Expiration,
TargetId: tar.GetPublicId(),
UserId: at.GetIamUserId(),
HostSetId: tc.hostSourceId,
HostId: tc.wantedHostId,
Type: "tcp",
Endpoint: fmt.Sprintf("tcp://%s:%d", tc.wantedEndpoint, defaultPort),
Credentials: []*pb.SessionCredential{tc.wantedCred},
EndpointPort: uint32(defaultPort),
Expiration: asRes.Item.Expiration,
ConnectionLimit: tc.wantedConnectionLimit,
// TODO: validate the contents of the authorization token is what is expected
}
got := asRes.GetItem()

Loading…
Cancel
Save