diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index f6e92d28d1..e269244d88 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -1214,6 +1214,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession HostSetId: hostSetId, Endpoint: endpointUrl.String(), Credentials: creds, + ConnectionLimit: t.GetSessionConnectionLimit(), } ret.SessionRecordingId, err = SessionRecordingFn( diff --git a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go index 836c6a28b7..99d9410686 100644 --- a/internal/daemon/controller/handlers/targets/tcp/target_service_test.go +++ b/internal/daemon/controller/handlers/targets/tcp/target_service_test.go @@ -3547,32 +3547,36 @@ func TestAuthorizeSession(t *testing.T) { const defaultPort = 2 cases := []struct { - name string - hostSourceId string - credSourceId string - wantedHostId string - wantedEndpoint string + name string + hostSourceId string + credSourceId string + wantedHostId string + wantedEndpoint string + wantedConnectionLimit int32 }{ { - name: "static host", - hostSourceId: shs.GetPublicId(), - credSourceId: clsResp.GetItem().GetId(), - wantedHostId: h.GetPublicId(), - wantedEndpoint: fmt.Sprintf("%s:%d", h.GetAddress(), defaultPort), + name: "static host", + hostSourceId: shs.GetPublicId(), + credSourceId: clsResp.GetItem().GetId(), + wantedHostId: h.GetPublicId(), + wantedEndpoint: fmt.Sprintf("%s:%d", h.GetAddress(), defaultPort), + wantedConnectionLimit: -1, }, { - name: "static host with port defined", - hostSourceId: shsWithPort.GetPublicId(), - credSourceId: clsResp.GetItem().GetId(), - wantedHostId: hWithPort.GetPublicId(), - wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort), + name: "static host with port defined", + hostSourceId: shsWithPort.GetPublicId(), + credSourceId: clsResp.GetItem().GetId(), + wantedHostId: hWithPort.GetPublicId(), + wantedEndpoint: fmt.Sprintf("%s:%d", hWithPortBareAddress, defaultPort), + wantedConnectionLimit: 10, }, { - name: "plugin host", - hostSourceId: phs.GetPublicId(), - credSourceId: clsResp.GetItem().GetId(), - wantedHostId: "?", - wantedEndpoint: fmt.Sprintf("10.0.0.1:%d", defaultPort), + name: "plugin host", + hostSourceId: phs.GetPublicId(), + credSourceId: clsResp.GetItem().GetId(), + wantedHostId: "?", + wantedEndpoint: fmt.Sprintf("10.0.0.1:%d", defaultPort), + wantedConnectionLimit: 100, }, } @@ -3588,7 +3592,7 @@ func TestAuthorizeSession(t *testing.T) { ctx, err = event.NewCorrelationIdContext(ctx, corId) require.NoError(t, err) - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort)) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort), target.WithSessionConnectionLimit(tc.wantedConnectionLimit)) apiTar, err := s.AddTargetHostSources(ctx, &pbs.AddTargetHostSourcesRequest{ Id: tar.GetPublicId(), Version: tar.GetVersion(), @@ -3660,7 +3664,8 @@ func TestAuthorizeSession(t *testing.T) { }, }, }, - EndpointPort: uint32(defaultPort), + EndpointPort: uint32(defaultPort), + ConnectionLimit: tc.wantedConnectionLimit, // TODO: validate the contents of the authorization token is what is expected } wantSecret := map[string]any{ @@ -3991,12 +3996,13 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { require.NotNil(t, clsRespSshPrivateKeyWithPassWithMapping) cases := []struct { - name string - hostSourceId string - credSourceId string - wantedHostId string - wantedEndpoint string - wantedCred *pb.SessionCredential + name string + hostSourceId string + credSourceId string + wantedHostId string + wantedEndpoint string + wantedCred *pb.SessionCredential + wantedConnectionLimit int32 }{ { name: "vault-unspecified", @@ -4013,6 +4019,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { Type: vault.GenericLibrarySubtype.String(), }, }, + wantedConnectionLimit: 1, }, { name: "vault-usernamepassword", @@ -4039,6 +4046,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 10, }, { name: "vault-UsernamePassword-with-mapping", @@ -4065,6 +4073,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 100, }, { name: "static-UsernamePassword", @@ -4091,6 +4100,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 1000, }, { name: "static-ssh-private-key", @@ -4117,6 +4127,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 10000, }, { name: "vault-ssh-private-key", @@ -4143,6 +4154,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 50000, }, { name: "vault-ssh-private-key-with-mapping", @@ -4169,6 +4181,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 100000, }, { name: "static-ssh-private-key-with-pass", @@ -4196,6 +4209,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 500000, }, { name: "vault-ssh-private-key-with-pass", @@ -4223,6 +4237,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: 1000000, }, { name: "vault-ssh-private-key-with-pass-with-mapping", @@ -4250,13 +4265,14 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { return st }(), }, + wantedConnectionLimit: -1, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { const defaultPort = 2 - tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort)) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), tc.name, target.WithDefaultPort(defaultPort), target.WithSessionConnectionLimit(tc.wantedConnectionLimit)) apiTar, err := s.AddTargetHostSources(ctx, &pbs.AddTargetHostSourcesRequest{ Id: tar.GetPublicId(), Version: tar.GetVersion(), @@ -4293,15 +4309,16 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) { Description: proj.GetDescription(), ParentScopeId: proj.GetParentId(), }, - TargetId: tar.GetPublicId(), - UserId: at.GetIamUserId(), - HostSetId: tc.hostSourceId, - HostId: tc.wantedHostId, - Type: "tcp", - Endpoint: fmt.Sprintf("tcp://%s:%d", tc.wantedEndpoint, defaultPort), - Credentials: []*pb.SessionCredential{tc.wantedCred}, - EndpointPort: uint32(defaultPort), - Expiration: asRes.Item.Expiration, + TargetId: tar.GetPublicId(), + UserId: at.GetIamUserId(), + HostSetId: tc.hostSourceId, + HostId: tc.wantedHostId, + Type: "tcp", + Endpoint: fmt.Sprintf("tcp://%s:%d", tc.wantedEndpoint, defaultPort), + Credentials: []*pb.SessionCredential{tc.wantedCred}, + EndpointPort: uint32(defaultPort), + Expiration: asRes.Item.Expiration, + ConnectionLimit: tc.wantedConnectionLimit, // TODO: validate the contents of the authorization token is what is expected } got := asRes.GetItem()