diff --git a/internal/db/schema/migrations/oss/postgres/98/06_target_credential_injection_certificates.up.sql b/internal/db/schema/migrations/oss/postgres/98/06_target_credential_injection_certificates.up.sql index 28f9bbe56d..cfeec37e21 100644 --- a/internal/db/schema/migrations/oss/postgres/98/06_target_credential_injection_certificates.up.sql +++ b/internal/db/schema/migrations/oss/postgres/98/06_target_credential_injection_certificates.up.sql @@ -137,4 +137,22 @@ create table session_proxy_certificate( comment on table session_proxy_certificate is 'session_proxy_certificate is a table where each row maps a certificate to a session id.'; +-- delete_session_proxy_certificate_with_null_fk is intended to be a before update trigger that +-- deletes the session_proxy_certificate entry for a session if the project_id of the session is being set to null. +-- If the project_id is null, this indicates that the kms key used to encrypt the session_proxy_certificate is being deleted +-- and therefore the session_proxy_certificate entry must also be deleted. +create or replace function delete_session_proxy_certificate_with_null_fk() returns trigger +as $$ +begin + case + when new.project_id is null then + delete from session_proxy_certificate where session_id = new.public_id; + end case; +return new; +end; +$$ language plpgsql; + +create trigger delete_session_proxy_certificate_with_null_fk before update of project_id on session + for each row execute procedure delete_session_proxy_certificate_with_null_fk(); + commit; \ No newline at end of file diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index ba3032fdd6..345daa024c 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -5,6 +5,7 @@ package session import ( "context" + "database/sql" "testing" "time" @@ -2004,3 +2005,119 @@ func TestRepository_LookupProxyCertificate(t *testing.T) { }) } } + +func TestRepository_ProxyCertificateDeletion(t *testing.T) { + ctx := t.Context() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(ctx, rw, rw, kmsCache) + require.NoError(t, err) + iamRepo := iam.TestRepo(t, conn, wrapper) + + t.Run("Deleting the project deletes the cert", func(t *testing.T) { + org, proj := iam.TestScopes(t, iamRepo) + kmsWrapper, err := kmsCache.GetWrapper(context.Background(), proj.PublicId, kms.KeyPurposeSessions) + require.NoError(t, err) + + at := authtoken.TestAuthToken(t, conn, kmsCache, org.GetPublicId()) + uId := at.GetIamUserId() + hc := static.TestCatalogs(t, conn, proj.GetPublicId(), 1)[0] + hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0] + h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0] + static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h}) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()})) + session := TestSession(t, conn, wrapper, ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ProjectId: proj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + }) + + privKeyValue := []byte("fake-private-key") + certValue := []byte("fake-cert-value") + cert, err := NewProxyCertificate(ctx, session.PublicId, privKeyValue, certValue) + require.NoError(t, err) + require.NotNil(t, cert) + + err = cert.Encrypt(ctx, kmsWrapper) + require.NoError(t, err) + + err = rw.Create(ctx, cert) + require.NoError(t, err) + + // Check that the cert is there + got, err := repo.LookupProxyCertificate(ctx, proj.GetPublicId(), session.PublicId) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got.PrivateKey, privKeyValue) + assert.Equal(t, got.Certificate, certValue) + + n, err := iamRepo.DeleteScope(ctx, proj.GetPublicId()) + require.NoError(t, err) + require.Equal(t, n, 1) + got, err = repo.LookupProxyCertificate(ctx, proj.GetPublicId(), session.PublicId) + require.NoError(t, err) + require.Nil(t, got) + }) + + // Ensure that if project ID is set to null on a session for a reason other than the scope being + // removed that the cert is still removed + t.Run("Null project id on session deletes the cert", func(t *testing.T) { + org, proj := iam.TestScopes(t, iamRepo) + kmsWrapper, err := kmsCache.GetWrapper(context.Background(), proj.PublicId, kms.KeyPurposeSessions) + require.NoError(t, err) + + at := authtoken.TestAuthToken(t, conn, kmsCache, org.GetPublicId()) + uId := at.GetIamUserId() + hc := static.TestCatalogs(t, conn, proj.GetPublicId(), 1)[0] + hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0] + h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0] + static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h}) + tar := tcp.TestTarget(ctx, t, conn, proj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()})) + session := TestSession(t, conn, wrapper, ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ProjectId: proj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + }) + + privKeyValue := []byte("fake-private-key") + certValue := []byte("fake-cert-value") + cert, err := NewProxyCertificate(ctx, session.PublicId, privKeyValue, certValue) + require.NoError(t, err) + require.NotNil(t, cert) + + err = cert.Encrypt(ctx, kmsWrapper) + require.NoError(t, err) + err = rw.Create(ctx, cert) + require.NoError(t, err) + + // Check that the cert is there + got, err := repo.LookupProxyCertificate(ctx, proj.GetPublicId(), session.PublicId) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, got.PrivateKey, privKeyValue) + assert.Equal(t, got.Certificate, certValue) + + query := `update session set project_id = null where + public_id = @session_id` + + rowsAffected, err := rw.Exec(ctx, query, []any{ + sql.Named("session_id", session.PublicId), + }) + require.NoError(t, err) + require.Equal(t, 1, rowsAffected) + + got, err = repo.LookupProxyCertificate(ctx, proj.GetPublicId(), session.PublicId) + require.NoError(t, err) + require.Nil(t, got) + }) +}