From b6849b1c15bf25af703805d9e979159a65e093f4 Mon Sep 17 00:00:00 2001 From: Irena Rindos Date: Mon, 15 Sep 2025 13:42:47 -0400 Subject: [PATCH] fix(sess): grab kms wrapper before tx (#6052) --- .../handlers/targets/target_service.go | 2 +- internal/target/repository.go | 20 ++++++++++--------- ...epository_proxy_server_certificate_test.go | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/internal/daemon/controller/handlers/targets/target_service.go b/internal/daemon/controller/handlers/targets/target_service.go index df0184acae..f7ae111005 100644 --- a/internal/daemon/controller/handlers/targets/target_service.go +++ b/internal/daemon/controller/handlers/targets/target_service.go @@ -805,7 +805,7 @@ func (s Service) AuthorizeSession(ctx context.Context, req *pbs.AuthorizeSession if err != nil { return nil, err } - t, err := repo.LookupTargetForSessionAuthorization(ctx, roundTripTarget.GetPublicId(), target.WithAlias(targetAlias)) + t, err := repo.LookupTargetForSessionAuthorization(ctx, roundTripTarget.GetPublicId(), roundTripTarget.GetProjectId(), target.WithAlias(targetAlias)) if err != nil { if errors.IsNotFoundError(err) { return nil, handlers.NotFoundErrorf("Target %q not found.", roundTripTarget.GetPublicId()) diff --git a/internal/target/repository.go b/internal/target/repository.go index 21e455a7a8..bd8b7aa035 100644 --- a/internal/target/repository.go +++ b/internal/target/repository.go @@ -93,12 +93,19 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, // with its host source ids, credential source ids, and server certificate, if applicable. If the target is not // found, it will return nil, nil. // Supported option: WithAlias if the session authorization uses a target alias -func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, publicId string, opt ...Option) (Target, error) { +func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, publicId string, projectId string, opt ...Option) (Target, error) { const op = "target.(Repository).LookupTargetForSessionAuthorization" opts := GetOpts(opt...) - - if publicId == "" { + switch { + case publicId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id") + case projectId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing project id") + } + + databaseWrapper, err := r.kms.GetWrapper(ctx, projectId, kms.KeyPurposeDatabase) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) } target := allocTargetView() @@ -107,7 +114,7 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu var hostSources []HostSource var credSources []CredentialSource var cert *ServerCertificate - _, err := r.writer.DoTx( + _, err = r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, @@ -133,11 +140,6 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu address = targetAddress.GetAddress() } - databaseWrapper, err := r.kms.GetWrapper(ctx, target.GetProjectId(), kms.KeyPurposeDatabase) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) - } - if opts.WithAlias != nil { cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds()) if err != nil && !errors.IsNotFoundError(err) { diff --git a/internal/target/repository_proxy_server_certificate_test.go b/internal/target/repository_proxy_server_certificate_test.go index ea312403a7..041106a16b 100644 --- a/internal/target/repository_proxy_server_certificate_test.go +++ b/internal/target/repository_proxy_server_certificate_test.go @@ -372,7 +372,7 @@ func Test_FetchCertsWithinLookupTargetForSessionAuthorization(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - got, err := repo.LookupTargetForSessionAuthorization(ctx, tt.publicId, tt.opt...) + got, err := repo.LookupTargetForSessionAuthorization(ctx, tt.publicId, proj.PublicId, tt.opt...) require.NoError(err) assert.NotNil(got) if tt.wantCert {