diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index ccc708b3d8..60eb835ae6 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -252,15 +252,6 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. return errors.Wrap(ctx, err, op) } session.Connections = connections - if session.ProjectId == "" || session.UserId == "" { - // Skip decryption if Project ID or UserId is missing, - // since it will just lead to errors, and the session - // is already canceled if either of those are empty. - return nil - } - if err := decryptAndMaybeUpdateSession(ctx, r.kms, &session, w); err != nil && !opts.withIgnoreDecryptionFailures { - return errors.Wrap(ctx, err, op) - } return nil }, ) @@ -271,6 +262,15 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. return nil, nil, errors.Wrap(ctx, err, op) } + // Skip decryption if Project ID or UserId is missing, + // since it will just lead to errors, and the session + // is already canceled if either of those are empty. + if session.ProjectId != "" && session.UserId != "" { + if err := decrypt(ctx, r.kms, &session); err != nil && !opts.withIgnoreDecryptionFailures { + return nil, nil, errors.Wrap(ctx, err, op) + } + } + authzSummary, err := r.sessionAuthzSummary(ctx, sessionId) if err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("failed to get authz summary")) @@ -602,7 +602,7 @@ func (r *Repository) lookupActivatedSessionTx(ctx context.Context, reader db.Rea if txErr = reader.LookupById(ctx, activatedSession); txErr != nil { return errors.Wrap(ctx, txErr, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) } - if txErr = decryptAndMaybeUpdateSession(ctx, r.kms, activatedSession, writer); txErr != nil { + if txErr = decrypt(ctx, r.kms, activatedSession); txErr != nil { return errors.Wrap(ctx, txErr, op) } if len(activatedSession.TofuToken) > 0 && subtle.ConstantTimeCompare(activatedSession.TofuToken, tofuToken) != 1 { @@ -637,10 +637,10 @@ func (r *Repository) getActivatedSession(ctx context.Context, sessionId string, ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(reader db.Reader, w db.Writer) error { - err := r.lookupActivatedSessionTx(ctx, reader, w, sessionId, tofuToken, &activatedSession) + func(reader db.Reader, _ db.Writer) error { + err := reader.LookupById(ctx, &activatedSession) if err != nil { - return errors.Wrap(ctx, err, op) + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) } returnedStates, err = r.fetchActivatedSessionStatesTx(ctx, reader, sessionId) if err != nil { @@ -652,6 +652,12 @@ func (r *Repository) getActivatedSession(ctx context.Context, sessionId string, if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } + if err := decrypt(ctx, r.kms, &activatedSession); err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + if len(activatedSession.TofuToken) > 0 && subtle.ConstantTimeCompare(activatedSession.TofuToken, tofuToken) != 1 { + return nil, nil, errors.New(ctx, errors.TokenMismatch, op, "tofu token mismatch") + } return &activatedSession, returnedStates, nil } @@ -674,10 +680,28 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing tofu token") } + // Lookup session first to get the project id so the correct kms wrapper can be used for encrypting the tofu. + foundSession := AllocSession() + foundSession.PublicId = sessionId + if err := r.reader.LookupById(ctx, &foundSession); err != nil { + return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) + } + + // Encrypt the tofu before we start a database transaction to avoid holding the transaction while encrypting. updatedSession := AllocSession() updatedSession.PublicId = sessionId + updatedSession.TofuToken = tofuToken + sessionWrapper, err := r.kms.GetWrapper(ctx, foundSession.ProjectId, kms.KeyPurposeSessions) + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper")) + } + if err := updatedSession.encrypt(ctx, sessionWrapper); err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + + var tofuSeen bool var returnedStates []*State - _, err := r.writer.DoTx( + _, err = r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, @@ -692,21 +716,26 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess if rowsAffected == 0 { return errors.New(ctx, errors.InvalidSessionState, op, "session is not in a pending state") } - foundSession := AllocSession() + + foundSession = AllocSession() foundSession.PublicId = sessionId - err = r.lookupActivatedSessionTx(ctx, reader, w, sessionId, tofuToken, &foundSession) + err = reader.LookupById(ctx, &foundSession) if err != nil { - return errors.Wrap(ctx, err, op) + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId))) } - updatedSession.TofuToken = tofuToken - sessionWrapper, err := r.kms.GetWrapper(ctx, foundSession.ProjectId, kms.KeyPurposeSessions) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper")) - } - if err := updatedSession.encrypt(ctx, sessionWrapper); err != nil { - return errors.Wrap(ctx, err, op) + // If we already have recorded a tofu, we don't need to update anything. + // Once we are out of the transaction, we can decrypt and check if the + // recorded tofu matches. + if len(foundSession.CtTofuToken) > 0 { + tofuSeen = true + returnedStates, err = r.fetchActivatedSessionStatesTx(ctx, reader, sessionId) + if err != nil { + return errors.Wrap(ctx, err, op) + } + return nil } + rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"CtTofuToken", "KeyId"}, nil) if err != nil { return errors.Wrap(ctx, err, op) @@ -731,6 +760,16 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess } return nil, nil, errors.Wrap(ctx, err, op) } + if tofuSeen { + if err := decrypt(ctx, r.kms, &foundSession); err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + if subtle.ConstantTimeCompare(foundSession.TofuToken, tofuToken) != 1 { + return nil, nil, errors.New(ctx, errors.TokenMismatch, op, "tofu token mismatch") + } + return &foundSession, returnedStates, nil + } + return &updatedSession, returnedStates, nil } @@ -774,10 +813,6 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV if rowsUpdated != 1 { return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated session and %d rows updated", rowsUpdated)) } - if err := decryptAndMaybeUpdateSession(ctx, r.kms, &updatedSession, w); err != nil && !opts.withIgnoreDecryptionFailures { - return errors.Wrap(ctx, err, op) - } - rowsAffected, err = w.Exec(ctx, updateSessionState, []any{ sql.Named("session_id", sessionId), sql.Named("status", s.String()), @@ -807,6 +842,11 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV if err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("error creating new state")) } + + if err := decrypt(ctx, r.kms, &updatedSession); err != nil && !opts.withIgnoreDecryptionFailures { + return nil, nil, errors.Wrap(ctx, err, op) + } + return &updatedSession, returnedStates, nil } @@ -916,21 +956,15 @@ func fetchHostSetHost(ctx context.Context, r db.Reader, sessionId string, opt .. return hostSetHost, nil } -// decryptAndMaybeUpdateSession switches between the database key and session key based on whether -// the session uses a legacy private key or not. It also updates the encrypted session -// in the database (if necessary). Eventually we should be able to remove this function -// and use the session key unconditionally. -func decryptAndMaybeUpdateSession(ctx context.Context, kmsRepo kms.GetWrapperer, session *Session, writer db.Writer) error { - const op = "session.decryptAndMaybeUpdateSession" +// decrypt decrypts encrypted fields of the Session. +func decrypt(ctx context.Context, kmsRepo kms.GetWrapperer, session *Session) error { + const op = "session.decrypt" if util.IsNil(kmsRepo) { return errors.New(ctx, errors.InvalidParameter, op, "missing kms repo") } if session == nil { return errors.New(ctx, errors.InvalidParameter, op, "missing session") } - if util.IsNil(writer) { - return errors.New(ctx, errors.InvalidParameter, op, "missing writer") - } if session.ProjectId == "" { return errors.New(ctx, errors.InvalidParameter, op, "missing session project ID") } @@ -947,43 +981,8 @@ func decryptAndMaybeUpdateSession(ctx context.Context, kmsRepo kms.GetWrapperer, if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper")) } - if len(session.CtCertificatePrivateKey) > 0 { - // New-style session with private key stored in DB, just decrypt. - if err := session.decrypt(ctx, sessionWrapper); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session value")) - } - return nil - } - // No certificate private key present, this is a legacy session with a - // private key derived from the key. Derive it again and store it back - // in the DB. - _, session.CertificatePrivateKey, err = DeriveED25519Key(ctx, sessionWrapper, session.UserId, session.PublicId) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("Error deriving session key")) - } - updatedFields := []string{"CtCertificatePrivateKey", "KeyId"} - if len(session.CtTofuToken) > 0 { - // If the TOFU token was set on this session, we can decrypt it using - // the database key. - databaseWrapper, err := kmsRepo.GetWrapper(ctx, session.ProjectId, kms.KeyPurposeDatabase) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) - } - if err := session.decrypt(ctx, databaseWrapper); err != nil { - // Note; we can hit this error if the database key that - // was used to encrypt the TOFU token is no longer available - // in the wrapper. Try to return a useful error to the user. - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session TOFU token. You may need to recreate your session.")) - } - updatedFields = append(updatedFields, "CtTofuToken") - } - // Rewrap with the session wrapper. Next time we look up this session - // all values will be encrypted using the session key. - if err := session.encrypt(ctx, sessionWrapper); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt session value")) - } - if _, err := writer.Update(ctx, session, updatedFields, nil); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update session value")) + if err := session.decrypt(ctx, sessionWrapper); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session value")) } return nil } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index dd049afd08..b8faff00f8 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -1919,9 +1919,8 @@ func TestRepository_deleteTerminated(t *testing.T) { } } -func Test_decryptAndMaybeUpdateSession(t *testing.T) { +func Test_decrypt(t *testing.T) { conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) wrapper := db.TestWrapper(t) kmsRepo := kms.TestKms(t, conn, wrapper) iamRepo := iam.TestRepo(t, conn, wrapper) @@ -1929,89 +1928,41 @@ func Test_decryptAndMaybeUpdateSession(t *testing.T) { t.Run("errors-with-invalid-kms", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) - err := decryptAndMaybeUpdateSession(ctx, nil, s, rw) + err := decrypt(ctx, nil, s) require.Error(t, err) }) t.Run("errors-with-invalid-session", func(t *testing.T) { - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, nil, rw) - require.Error(t, err) - }) - t.Run("errors-with-invalid-writer", func(t *testing.T) { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, nil) + err := decrypt(ctx, kmsRepo, nil) require.Error(t, err) }) t.Run("errors-with-invalid-session-project-id", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) s.ProjectId = "" - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) + err := decrypt(ctx, kmsRepo, s) require.Error(t, err) }) t.Run("errors-with-invalid-session-key-id", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) s.KeyId = "" - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) + err := decrypt(ctx, kmsRepo, s) require.Error(t, err) }) t.Run("errors-with-invalid-session-user-id", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) s.UserId = "" - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) + err := decrypt(ctx, kmsRepo, s) require.Error(t, err) }) t.Run("errors-with-invalid-session-public-id", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) s.PublicId = "" - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) + err := decrypt(ctx, kmsRepo, s) require.Error(t, err) }) t.Run("session-with-local-session-key-succeeds", func(t *testing.T) { s := TestDefaultSession(t, conn, wrapper, iamRepo) - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) - require.NoError(t, err) - }) - t.Run("session-with-derived-session-key-succeeds", func(t *testing.T) { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - s.CtCertificatePrivateKey = nil - s.CertificatePrivateKey = nil - s.TofuToken = nil - s.CtTofuToken = nil - err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) - require.NoError(t, err) - }) - t.Run("session-with-derived-session-key-and-tofu-token-succeeds", func(t *testing.T) { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - s.CtCertificatePrivateKey = nil - s.CertificatePrivateKey = nil - s.TofuToken = []byte("A token") - actualKeyId := s.KeyId - databaseWrapper, err := kmsRepo.GetWrapper(ctx, s.ProjectId, kms.KeyPurposeDatabase) - require.NoError(t, err) - err = s.encrypt(ctx, databaseWrapper) - require.NoError(t, err) - s.KeyId = actualKeyId // Restore this as the encrypt call above will overwrite it. - err = decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) - require.NoError(t, err) - }) - t.Run("session-with-derived-session-key-and-tofu-token-cannot-be-decrypted", func(t *testing.T) { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - s.CtCertificatePrivateKey = nil - s.CertificatePrivateKey = nil - s.TofuToken = []byte("A token") - actualKeyId := s.KeyId - databaseWrapper, err := kmsRepo.GetWrapper(ctx, s.ProjectId, kms.KeyPurposeDatabase) - require.NoError(t, err) - err = s.encrypt(ctx, databaseWrapper) - require.NoError(t, err) - databaseKeyId := s.KeyId - err = kmsRepo.RotateKeys(ctx, s.ProjectId) - require.NoError(t, err) - ok, err := kmsRepo.DestroyKeyVersion(ctx, s.ProjectId, databaseKeyId) + err := decrypt(ctx, kmsRepo, s) require.NoError(t, err) - assert.True(t, ok) - s.KeyId = actualKeyId // Restore this as the encrypt call above will overwrite it. - err = decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw) - require.ErrorContains(t, err, "You may need to recreate your session") }) } diff --git a/internal/session/rewrapping.go b/internal/session/rewrapping.go index ba79af0a5f..947cf41bd8 100644 --- a/internal/session/rewrapping.go +++ b/internal/session/rewrapping.go @@ -109,7 +109,7 @@ func sessionRewrapFn(ctx context.Context, dataKeyVersionId string, scopeId strin } continue } - if err := decryptAndMaybeUpdateSession(ctx, kmsRepo, session, writer); err != nil { + if err := decrypt(ctx, kmsRepo, session); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("failed to decrypt session")) } wrapper, err := kmsRepo.GetWrapper(ctx, session.GetProjectId(), kms.KeyPurposeSessions)