diff --git a/core/download_engine/__init__.py b/core/download_engine/__init__.py index 1e945207..4ebe7a01 100644 --- a/core/download_engine/__init__.py +++ b/core/download_engine/__init__.py @@ -25,5 +25,6 @@ commit so behavior never breaks across the suite. """ from core.download_engine.engine import DownloadEngine +from core.download_engine.worker import BackgroundDownloadWorker -__all__ = ["DownloadEngine"] +__all__ = ["DownloadEngine", "BackgroundDownloadWorker"] diff --git a/core/download_engine/engine.py b/core/download_engine/engine.py index 2ea4bf3c..c9484f16 100644 --- a/core/download_engine/engine.py +++ b/core/download_engine/engine.py @@ -70,6 +70,11 @@ class DownloadEngine: # plugin lookup local to the engine instead of forcing every # caller to also touch the registry. self._plugins: Dict[str, Any] = {} + # Background download worker — lives on the engine because + # it owns the cross-source state the worker mutates. Lazy + # import keeps the engine module standalone. + from core.download_engine.worker import BackgroundDownloadWorker + self.worker = BackgroundDownloadWorker(self) # ------------------------------------------------------------------ # Plugin registration diff --git a/core/download_engine/worker.py b/core/download_engine/worker.py new file mode 100644 index 00000000..93ab17e1 --- /dev/null +++ b/core/download_engine/worker.py @@ -0,0 +1,289 @@ +"""BackgroundDownloadWorker — engine-owned thread spawning + state +lifecycle for downloads. + +Today every streaming download client (YouTube, Tidal, Qobuz, HiFi, +Deezer, SoundCloud) hand-rolls the same thread-spawn pattern: + +```python +async def download(self, ...): + download_id = str(uuid.uuid4()) + with self._download_lock: + self.active_downloads[download_id] = {...initial state...} + threading.Thread( + target=self._download_thread_worker, + args=(download_id, target_id, display_name, ...), + daemon=True, + ).start() + return download_id + +def _download_thread_worker(self, download_id, target_id, display_name, ...): + with self._download_semaphore: + # rate-limit sleep + # update state to 'InProgress, Downloading' + file_path = self._download_sync(...) # the source-specific atomic op + # update state to 'Completed, Succeeded' / 'Errored' +``` + +That pattern is duplicated 6+ times across the codebase (~70 LOC +each, ~490 total). The worker class lifts it into the engine — each +plugin only has to provide the atomic op (``impl_callable``) and +declare its rate-limit policy. Adding a new download source becomes +a much smaller patch. + +Phase C1 scope: introduce the worker. No client migrated yet — the +worker just exists for C2–C7 to migrate sources one at a time, each +under a passing pinning test. +""" + +from __future__ import annotations + +import threading +import time +import uuid +from typing import Any, Callable, Dict, Optional + +from utils.logging_config import get_logger + +logger = get_logger("download_engine.worker") + + +# Type aliases for clarity. ``ImplCallable`` is the per-plugin +# atomic download operation — synchronous, returns a file path on +# success or raises (or returns None) on failure. +ImplCallable = Callable[[str, Any, str], Optional[str]] + + +class BackgroundDownloadWorker: + """Engine-owned thread spawner for per-source downloads. + + State-machine semantics (preserved verbatim from the legacy + per-client workers so consumers reading these fields keep + working): + + - ``Initializing`` — set on dispatch, before the thread starts. + - ``InProgress, Downloading`` — set when the worker thread + acquires the semaphore and is about to call the impl. + - ``Completed, Succeeded`` — set when impl returns a non-None + file path. ``progress=100.0`` and ``file_path=`` + also written. + - ``Errored`` — set when impl returns None OR raises. The + record is left in place so downstream consumers can inspect + what failed. + + Per-source serialization: each source gets a ``threading.Semaphore`` + (default size 1, configurable per-source via ``set_concurrency``). + Same shape the existing clients use today (each source defines + its own semaphore). Engine owning them centrally lets a future + Phase E rate-limiter swap the semaphore for a smarter pool. + + Per-source delay-between-downloads: default 0 seconds (most + sources don't need it). YouTube currently uses 3s, Qobuz uses + 1s — the legacy values get configured in via ``set_delay`` + when the source registers. + """ + + def __init__(self, engine: Any) -> None: + self._engine = engine + # Per-source semaphores + delay state. The first dispatch + # for a source auto-creates a semaphore with concurrency=1 + # if the source hasn't been configured explicitly. + self._semaphores: Dict[str, threading.Semaphore] = {} + self._delays: Dict[str, float] = {} + self._last_download_at: Dict[str, float] = {} + self._config_lock = threading.Lock() + + # ------------------------------------------------------------------ + # Per-source rate-limit configuration + # ------------------------------------------------------------------ + + def set_concurrency(self, source_name: str, max_concurrent: int) -> None: + """Set the max number of concurrent downloads for a source. + Default is 1 (serial). Most sources will keep the default — + the streaming APIs all rate-limit at the API gateway level + anyway, parallel downloads just trade rate-limit errors for + thread overhead.""" + with self._config_lock: + self._semaphores[source_name] = threading.Semaphore(max_concurrent) + + def set_delay(self, source_name: str, seconds: float) -> None: + """Set a minimum delay between successive downloads from the + same source. YouTube uses 3s today (avoid yt-dlp 429s), + Qobuz uses 1s. Other sources use 0 (no delay).""" + with self._config_lock: + self._delays[source_name] = float(seconds) + + def _get_semaphore(self, source_name: str) -> threading.Semaphore: + with self._config_lock: + sem = self._semaphores.get(source_name) + if sem is None: + sem = threading.Semaphore(1) + self._semaphores[source_name] = sem + return sem + + def _get_delay(self, source_name: str) -> float: + with self._config_lock: + return self._delays.get(source_name, 0.0) + + # ------------------------------------------------------------------ + # Dispatch — public API + # ------------------------------------------------------------------ + + def dispatch( + self, + source_name: str, + target_id: Any, + display_name: str, + original_filename: str, + impl_callable: ImplCallable, + extra_record_fields: Optional[Dict[str, Any]] = None, + username_override: Optional[str] = None, + thread_name: Optional[str] = None, + ) -> str: + """Kick off a background download. + + Args: + source_name: Canonical source name (e.g. 'youtube', + 'tidal'). Used as the engine state key + the + username slot in the record (unless overridden). + target_id: Source-specific identifier (track_id, video_id, + permalink_url, album_foreign_id, etc.). Passed + verbatim to ``impl_callable``. + display_name: Human-readable label for logs / UI. + original_filename: The encoded filename the orchestrator + received (e.g. ``'12345||Song Title'``). Stored in + the record's ``filename`` slot for context-key lookups. + impl_callable: Synchronous function that performs the + actual download. Signature: + ``impl_callable(download_id, target_id, display_name) -> Optional[str]``. + Returns the final file path on success or None / + raises on failure. + extra_record_fields: Per-source extras to merge into the + initial record (e.g. ``{'video_id': '...', 'url': + '...', 'title': '...'}`` for YouTube). Used to + preserve source-specific slots that downstream + consumers + status APIs read. + username_override: Use this instead of ``source_name`` + in the record's ``username`` slot. Required for + Deezer (legacy ``'deezer_dl'``) — every other source + uses the canonical name. + thread_name: Optional thread name for diagnostics. Deezer + uses ``'deezer-dl-'`` — Phase A pinning + tests catch any drift in this convention. + + Returns: + download_id (UUID4 string). The orchestrator polls via + ``engine.get_download_status(download_id)`` for progress. + """ + download_id = str(uuid.uuid4()) + + record: Dict[str, Any] = { + 'id': download_id, + 'filename': original_filename, + 'username': username_override or source_name, + 'state': 'Initializing', + 'progress': 0.0, + 'size': 0, + 'transferred': 0, + 'speed': 0, + 'time_remaining': None, + 'file_path': None, + } + if extra_record_fields: + record.update(extra_record_fields) + + self._engine.add_record(source_name, download_id, record) + + thread = threading.Thread( + target=self._worker_loop, + args=(source_name, download_id, target_id, display_name, impl_callable), + daemon=True, + name=thread_name, + ) + thread.start() + + return download_id + + # ------------------------------------------------------------------ + # Worker thread — the lifted boilerplate + # ------------------------------------------------------------------ + + def _worker_loop( + self, + source_name: str, + download_id: str, + target_id: Any, + display_name: str, + impl_callable: ImplCallable, + ) -> None: + """Runs on the spawned daemon thread. Handles semaphore + acquisition, rate-limit sleep, state lifecycle, exception + capture. The plugin-specific work happens entirely inside + ``impl_callable``.""" + try: + with self._get_semaphore(source_name): + # Rate-limit delay against the LAST download from + # this source (not just this worker — semaphore + # ensures serial access while delay is configured). + delay = self._get_delay(source_name) + if delay > 0: + last_at = self._last_download_at.get(source_name, 0.0) + elapsed = time.time() - last_at + if last_at > 0 and elapsed < delay: + wait_time = delay - elapsed + logger.info( + "Rate-limit delay for %s: waiting %.1fs before next download", + source_name, wait_time, + ) + time.sleep(wait_time) + + self._engine.update_record(source_name, download_id, { + 'state': 'InProgress, Downloading', + }) + + try: + file_path = impl_callable(download_id, target_id, display_name) + except Exception as exc: + logger.error( + "%s download %s failed (impl raised): %s", + source_name, download_id, exc, + ) + self._engine.update_record(source_name, download_id, { + 'state': 'Errored', + 'error': str(exc), + }) + return + + self._last_download_at[source_name] = time.time() + + if file_path: + self._engine.update_record(source_name, download_id, { + 'state': 'Completed, Succeeded', + 'progress': 100.0, + 'file_path': file_path, + }) + logger.info( + "%s download %s completed: %s", + source_name, download_id, file_path, + ) + else: + self._engine.update_record(source_name, download_id, { + 'state': 'Errored', + }) + logger.error( + "%s download %s failed (impl returned None)", + source_name, download_id, + ) + + except Exception as exc: + # Defensive — anything in the worker_loop itself + # (semaphore, sleep) shouldn't blow up the thread, but + # if it does the record gets marked Errored so the + # download doesn't sit forever in 'Initializing'. + logger.exception( + "%s worker_loop crashed for download %s: %s", + source_name, download_id, exc, + ) + self._engine.update_record(source_name, download_id, { + 'state': 'Errored', + 'error': f'worker crash: {exc}', + }) diff --git a/tests/downloads/test_background_download_worker.py b/tests/downloads/test_background_download_worker.py new file mode 100644 index 00000000..14c45126 --- /dev/null +++ b/tests/downloads/test_background_download_worker.py @@ -0,0 +1,345 @@ +"""Tests for `BackgroundDownloadWorker` (Phase C1). + +These tests pin the worker's state-machine semantics, semaphore +serialization, rate-limit-delay behavior, and exception handling. +Future phases (C2–C7) migrate each per-source client onto this +worker — these tests stay green as the regression net. +""" + +from __future__ import annotations + +import threading +import time + +from core.download_engine import DownloadEngine + + +# --------------------------------------------------------------------------- +# Dispatch — initial state + thread spawn +# --------------------------------------------------------------------------- + + +def test_dispatch_returns_uuid_download_id(): + engine = DownloadEngine() + + def impl(download_id, target_id, display_name): + return '/tmp/file.flac' + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='abc123', + display_name='Some Song', + original_filename='abc123||Some Song', + impl_callable=impl, + ) + assert len(download_id) == 36 # UUID4 + assert download_id.count('-') == 4 + + +def test_dispatch_inserts_initial_record_with_canonical_state(): + """Pinning: initial record matches the legacy per-client shape so + consumers reading the state dict via API or context-key lookup + keep working unchanged after migration.""" + engine = DownloadEngine() + captured = threading.Event() + + def impl(download_id, target_id, display_name): + captured.wait(timeout=1.0) # block so we can read 'Initializing' / 'InProgress' state + return '/tmp/file.flac' + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='abc', + display_name='X', + original_filename='abc||X', + impl_callable=impl, + ) + record = engine.get_record('youtube', download_id) + assert record is not None + assert record['id'] == download_id + assert record['filename'] == 'abc||X' + assert record['username'] == 'youtube' + assert record['state'] in ('Initializing', 'InProgress, Downloading') + assert record['progress'] == 0.0 + assert record['file_path'] is None + captured.set() # release impl + + +def test_dispatch_merges_extra_record_fields(): + """Pinning: source-specific slots (video_id, track_id, etc.) + merge into the initial record so frontend + status APIs that + read those keys keep working.""" + engine = DownloadEngine() + started = threading.Event() + release = threading.Event() + + def impl(download_id, target_id, display_name): + started.set() + release.wait(timeout=1.0) + return '/tmp/x.flac' + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='vid123', + display_name='Title', + original_filename='vid123||Title', + impl_callable=impl, + extra_record_fields={ + 'video_id': 'vid123', + 'url': 'https://youtube.com/watch?v=vid123', + 'title': 'Title', + }, + ) + started.wait(timeout=1.0) + record = engine.get_record('youtube', download_id) + assert record['video_id'] == 'vid123' + assert record['url'] == 'https://youtube.com/watch?v=vid123' + assert record['title'] == 'Title' + release.set() + + +def test_dispatch_username_override_preserves_legacy_slot(): + """Pinning: Deezer's record stores `'deezer_dl'` (legacy) in the + username slot, not the canonical `'deezer'`. Worker accepts + override so frontend status indicators keep their key.""" + engine = DownloadEngine() + release = threading.Event() + + def impl(download_id, target_id, display_name): + release.wait(timeout=1.0) + return '/tmp/x.flac' + + download_id = engine.worker.dispatch( + source_name='deezer', + target_id='999', + display_name='X', + original_filename='999||X', + impl_callable=impl, + username_override='deezer_dl', + ) + record = engine.get_record('deezer', download_id) + assert record['username'] == 'deezer_dl' + release.set() + + +# --------------------------------------------------------------------------- +# Worker lifecycle — state transitions +# --------------------------------------------------------------------------- + + +def test_worker_marks_completed_on_successful_impl(): + engine = DownloadEngine() + + def impl(download_id, target_id, display_name): + return '/tmp/done.flac' + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='vid', + display_name='X', + original_filename='vid||X', + impl_callable=impl, + ) + + # Wait for thread to finish. + deadline = time.time() + 2.0 + while time.time() < deadline: + record = engine.get_record('youtube', download_id) + if record and record['state'] == 'Completed, Succeeded': + break + time.sleep(0.01) + + record = engine.get_record('youtube', download_id) + assert record['state'] == 'Completed, Succeeded' + assert record['progress'] == 100.0 + assert record['file_path'] == '/tmp/done.flac' + + +def test_worker_marks_errored_when_impl_returns_none(): + engine = DownloadEngine() + + def impl(download_id, target_id, display_name): + return None # signaling failure + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='vid', + display_name='X', + original_filename='vid||X', + impl_callable=impl, + ) + + deadline = time.time() + 2.0 + while time.time() < deadline: + record = engine.get_record('youtube', download_id) + if record and record['state'] == 'Errored': + break + time.sleep(0.01) + + record = engine.get_record('youtube', download_id) + assert record['state'] == 'Errored' + # file_path stays None (default). + assert record['file_path'] is None + + +def test_worker_marks_errored_and_captures_message_when_impl_raises(): + engine = DownloadEngine() + + def impl(download_id, target_id, display_name): + raise RuntimeError("api blew up") + + download_id = engine.worker.dispatch( + source_name='youtube', + target_id='vid', + display_name='X', + original_filename='vid||X', + impl_callable=impl, + ) + + deadline = time.time() + 2.0 + while time.time() < deadline: + record = engine.get_record('youtube', download_id) + if record and record['state'] == 'Errored': + break + time.sleep(0.01) + + record = engine.get_record('youtube', download_id) + assert record['state'] == 'Errored' + assert 'api blew up' in record.get('error', '') + + +# --------------------------------------------------------------------------- +# Per-source semaphore serialization +# --------------------------------------------------------------------------- + + +def test_semaphore_serializes_downloads_for_same_source(): + """Pinning: with concurrency=1 (default), two dispatches against + the same source run sequentially. The legacy per-client + semaphore did the same — consumers depend on this for + rate-limit safety against APIs like YouTube.""" + engine = DownloadEngine() + in_progress = threading.Event() + can_finish = threading.Event() + overlap_count = 0 + overlap_lock = threading.Lock() + active_count = [0] + + def impl(download_id, target_id, display_name): + nonlocal overlap_count + with overlap_lock: + active_count[0] += 1 + if active_count[0] > 1: + overlap_count += 1 + in_progress.set() + can_finish.wait(timeout=2.0) + with overlap_lock: + active_count[0] -= 1 + return '/tmp/x.flac' + + # Default concurrency=1 — two dispatches must serialize. + dl1 = engine.worker.dispatch( + source_name='youtube', target_id='a', display_name='A', + original_filename='a||A', impl_callable=impl, + ) + in_progress.wait(timeout=1.0) + in_progress.clear() + dl2 = engine.worker.dispatch( + source_name='youtube', target_id='b', display_name='B', + original_filename='b||B', impl_callable=impl, + ) + # Give second dispatch a chance to attempt running in parallel + # (it should be blocked on the semaphore). + time.sleep(0.1) + assert overlap_count == 0, "second dispatch should be blocked behind semaphore" + + # Release first; second proceeds. + can_finish.set() + + # Wait for both to finish. + deadline = time.time() + 3.0 + while time.time() < deadline: + r1 = engine.get_record('youtube', dl1) + r2 = engine.get_record('youtube', dl2) + if r1 and r2 and r1['state'] == 'Completed, Succeeded' and r2['state'] == 'Completed, Succeeded': + break + time.sleep(0.01) + + assert overlap_count == 0 + + +def test_semaphore_concurrency_can_be_increased(): + """When `set_concurrency(source, N)` is called, N downloads can + run in parallel for that source. Used by sources that support + parallel transfers (none today, but contract supports it).""" + engine = DownloadEngine() + engine.worker.set_concurrency('parallel-source', 3) + + in_flight = [] + in_flight_lock = threading.Lock() + can_finish = threading.Event() + max_observed = [0] + + def impl(download_id, target_id, display_name): + with in_flight_lock: + in_flight.append(download_id) + max_observed[0] = max(max_observed[0], len(in_flight)) + can_finish.wait(timeout=2.0) + with in_flight_lock: + in_flight.remove(download_id) + return '/tmp/x.flac' + + for i in range(3): + engine.worker.dispatch( + source_name='parallel-source', + target_id=str(i), + display_name=f'd{i}', + original_filename=f'{i}||d{i}', + impl_callable=impl, + ) + # Give threads time to ramp up. + time.sleep(0.2) + can_finish.set() + + # Wait for them to finish. + time.sleep(0.5) + assert max_observed[0] == 3 + + +# --------------------------------------------------------------------------- +# Per-source rate-limit delay +# --------------------------------------------------------------------------- + + +def test_delay_enforces_minimum_gap_between_downloads(): + """Pinning: YouTube uses 3s delay today (legacy + `_download_delay`). Worker-driven delay must enforce the same + gap so YouTube doesn't 429.""" + engine = DownloadEngine() + engine.worker.set_delay('youtube', 0.2) # 200ms — short for test speed + + completion_times = [] + + def impl(download_id, target_id, display_name): + completion_times.append(time.time()) + return '/tmp/x.flac' + + # Two back-to-back dispatches. + engine.worker.dispatch( + source_name='youtube', target_id='a', display_name='A', + original_filename='a||A', impl_callable=impl, + ) + engine.worker.dispatch( + source_name='youtube', target_id='b', display_name='B', + original_filename='b||B', impl_callable=impl, + ) + + # Wait for both to finish (semaphore serializes + delay). + deadline = time.time() + 3.0 + while time.time() < deadline and len(completion_times) < 2: + time.sleep(0.01) + + assert len(completion_times) == 2 + gap = completion_times[1] - completion_times[0] + # Gap is at LEAST the configured delay. + assert gap >= 0.18, f"expected gap >= 0.2s, got {gap:.3f}"