C1: Add BackgroundDownloadWorker to engine

`BackgroundDownloadWorker` lives on the engine and owns the
boilerplate every streaming download client currently
hand-rolls: thread spawn, per-source semaphore, rate-limit
delay, state lifecycle (Initializing → InProgress → Completed
or Errored), exception capture.

Plugins provide only the atomic download op (`impl_callable`).
Per-source rate-limit policy (concurrency, delay) is configured
on the worker via `set_concurrency` / `set_delay`. Source-
specific record fields merge in via `extra_record_fields` so
existing consumer code that reads `video_id`, `track_id`,
`permalink_url`, etc. keeps working post-migration. Username
slot supports override (Deezer's legacy `'deezer_dl'`).

Phase C1 scope: worker exists. No client migrated yet — C2-C7
migrate sources one at a time, each gated by the Phase A
pinning tests so per-source contract drift fails fast.

10 new tests pin the worker contract: UUID id format, initial
record shape, extra-fields merge, username override, state
transitions on success / impl-returns-None / impl-raises,
semaphore serialization (default + parallel), rate-limit
delay between successive downloads.

Suite still green (308 download tests). Pure additive.
pull/495/head
Broque Thomas 4 weeks ago
parent 3634dca83f
commit 78724861f9

@ -25,5 +25,6 @@ commit so behavior never breaks across the suite.
"""
from core.download_engine.engine import DownloadEngine
from core.download_engine.worker import BackgroundDownloadWorker
__all__ = ["DownloadEngine"]
__all__ = ["DownloadEngine", "BackgroundDownloadWorker"]

@ -70,6 +70,11 @@ class DownloadEngine:
# plugin lookup local to the engine instead of forcing every
# caller to also touch the registry.
self._plugins: Dict[str, Any] = {}
# Background download worker — lives on the engine because
# it owns the cross-source state the worker mutates. Lazy
# import keeps the engine module standalone.
from core.download_engine.worker import BackgroundDownloadWorker
self.worker = BackgroundDownloadWorker(self)
# ------------------------------------------------------------------
# Plugin registration

@ -0,0 +1,289 @@
"""BackgroundDownloadWorker — engine-owned thread spawning + state
lifecycle for downloads.
Today every streaming download client (YouTube, Tidal, Qobuz, HiFi,
Deezer, SoundCloud) hand-rolls the same thread-spawn pattern:
```python
async def download(self, ...):
download_id = str(uuid.uuid4())
with self._download_lock:
self.active_downloads[download_id] = {...initial state...}
threading.Thread(
target=self._download_thread_worker,
args=(download_id, target_id, display_name, ...),
daemon=True,
).start()
return download_id
def _download_thread_worker(self, download_id, target_id, display_name, ...):
with self._download_semaphore:
# rate-limit sleep
# update state to 'InProgress, Downloading'
file_path = self._download_sync(...) # the source-specific atomic op
# update state to 'Completed, Succeeded' / 'Errored'
```
That pattern is duplicated 6+ times across the codebase (~70 LOC
each, ~490 total). The worker class lifts it into the engine each
plugin only has to provide the atomic op (``impl_callable``) and
declare its rate-limit policy. Adding a new download source becomes
a much smaller patch.
Phase C1 scope: introduce the worker. No client migrated yet the
worker just exists for C2C7 to migrate sources one at a time, each
under a passing pinning test.
"""
from __future__ import annotations
import threading
import time
import uuid
from typing import Any, Callable, Dict, Optional
from utils.logging_config import get_logger
logger = get_logger("download_engine.worker")
# Type aliases for clarity. ``ImplCallable`` is the per-plugin
# atomic download operation — synchronous, returns a file path on
# success or raises (or returns None) on failure.
ImplCallable = Callable[[str, Any, str], Optional[str]]
class BackgroundDownloadWorker:
"""Engine-owned thread spawner for per-source downloads.
State-machine semantics (preserved verbatim from the legacy
per-client workers so consumers reading these fields keep
working):
- ``Initializing`` set on dispatch, before the thread starts.
- ``InProgress, Downloading`` set when the worker thread
acquires the semaphore and is about to call the impl.
- ``Completed, Succeeded`` set when impl returns a non-None
file path. ``progress=100.0`` and ``file_path=<the path>``
also written.
- ``Errored`` set when impl returns None OR raises. The
record is left in place so downstream consumers can inspect
what failed.
Per-source serialization: each source gets a ``threading.Semaphore``
(default size 1, configurable per-source via ``set_concurrency``).
Same shape the existing clients use today (each source defines
its own semaphore). Engine owning them centrally lets a future
Phase E rate-limiter swap the semaphore for a smarter pool.
Per-source delay-between-downloads: default 0 seconds (most
sources don't need it). YouTube currently uses 3s, Qobuz uses
1s the legacy values get configured in via ``set_delay``
when the source registers.
"""
def __init__(self, engine: Any) -> None:
self._engine = engine
# Per-source semaphores + delay state. The first dispatch
# for a source auto-creates a semaphore with concurrency=1
# if the source hasn't been configured explicitly.
self._semaphores: Dict[str, threading.Semaphore] = {}
self._delays: Dict[str, float] = {}
self._last_download_at: Dict[str, float] = {}
self._config_lock = threading.Lock()
# ------------------------------------------------------------------
# Per-source rate-limit configuration
# ------------------------------------------------------------------
def set_concurrency(self, source_name: str, max_concurrent: int) -> None:
"""Set the max number of concurrent downloads for a source.
Default is 1 (serial). Most sources will keep the default
the streaming APIs all rate-limit at the API gateway level
anyway, parallel downloads just trade rate-limit errors for
thread overhead."""
with self._config_lock:
self._semaphores[source_name] = threading.Semaphore(max_concurrent)
def set_delay(self, source_name: str, seconds: float) -> None:
"""Set a minimum delay between successive downloads from the
same source. YouTube uses 3s today (avoid yt-dlp 429s),
Qobuz uses 1s. Other sources use 0 (no delay)."""
with self._config_lock:
self._delays[source_name] = float(seconds)
def _get_semaphore(self, source_name: str) -> threading.Semaphore:
with self._config_lock:
sem = self._semaphores.get(source_name)
if sem is None:
sem = threading.Semaphore(1)
self._semaphores[source_name] = sem
return sem
def _get_delay(self, source_name: str) -> float:
with self._config_lock:
return self._delays.get(source_name, 0.0)
# ------------------------------------------------------------------
# Dispatch — public API
# ------------------------------------------------------------------
def dispatch(
self,
source_name: str,
target_id: Any,
display_name: str,
original_filename: str,
impl_callable: ImplCallable,
extra_record_fields: Optional[Dict[str, Any]] = None,
username_override: Optional[str] = None,
thread_name: Optional[str] = None,
) -> str:
"""Kick off a background download.
Args:
source_name: Canonical source name (e.g. 'youtube',
'tidal'). Used as the engine state key + the
username slot in the record (unless overridden).
target_id: Source-specific identifier (track_id, video_id,
permalink_url, album_foreign_id, etc.). Passed
verbatim to ``impl_callable``.
display_name: Human-readable label for logs / UI.
original_filename: The encoded filename the orchestrator
received (e.g. ``'12345||Song Title'``). Stored in
the record's ``filename`` slot for context-key lookups.
impl_callable: Synchronous function that performs the
actual download. Signature:
``impl_callable(download_id, target_id, display_name) -> Optional[str]``.
Returns the final file path on success or None /
raises on failure.
extra_record_fields: Per-source extras to merge into the
initial record (e.g. ``{'video_id': '...', 'url':
'...', 'title': '...'}`` for YouTube). Used to
preserve source-specific slots that downstream
consumers + status APIs read.
username_override: Use this instead of ``source_name``
in the record's ``username`` slot. Required for
Deezer (legacy ``'deezer_dl'``) every other source
uses the canonical name.
thread_name: Optional thread name for diagnostics. Deezer
uses ``'deezer-dl-<track_id>'`` Phase A pinning
tests catch any drift in this convention.
Returns:
download_id (UUID4 string). The orchestrator polls via
``engine.get_download_status(download_id)`` for progress.
"""
download_id = str(uuid.uuid4())
record: Dict[str, Any] = {
'id': download_id,
'filename': original_filename,
'username': username_override or source_name,
'state': 'Initializing',
'progress': 0.0,
'size': 0,
'transferred': 0,
'speed': 0,
'time_remaining': None,
'file_path': None,
}
if extra_record_fields:
record.update(extra_record_fields)
self._engine.add_record(source_name, download_id, record)
thread = threading.Thread(
target=self._worker_loop,
args=(source_name, download_id, target_id, display_name, impl_callable),
daemon=True,
name=thread_name,
)
thread.start()
return download_id
# ------------------------------------------------------------------
# Worker thread — the lifted boilerplate
# ------------------------------------------------------------------
def _worker_loop(
self,
source_name: str,
download_id: str,
target_id: Any,
display_name: str,
impl_callable: ImplCallable,
) -> None:
"""Runs on the spawned daemon thread. Handles semaphore
acquisition, rate-limit sleep, state lifecycle, exception
capture. The plugin-specific work happens entirely inside
``impl_callable``."""
try:
with self._get_semaphore(source_name):
# Rate-limit delay against the LAST download from
# this source (not just this worker — semaphore
# ensures serial access while delay is configured).
delay = self._get_delay(source_name)
if delay > 0:
last_at = self._last_download_at.get(source_name, 0.0)
elapsed = time.time() - last_at
if last_at > 0 and elapsed < delay:
wait_time = delay - elapsed
logger.info(
"Rate-limit delay for %s: waiting %.1fs before next download",
source_name, wait_time,
)
time.sleep(wait_time)
self._engine.update_record(source_name, download_id, {
'state': 'InProgress, Downloading',
})
try:
file_path = impl_callable(download_id, target_id, display_name)
except Exception as exc:
logger.error(
"%s download %s failed (impl raised): %s",
source_name, download_id, exc,
)
self._engine.update_record(source_name, download_id, {
'state': 'Errored',
'error': str(exc),
})
return
self._last_download_at[source_name] = time.time()
if file_path:
self._engine.update_record(source_name, download_id, {
'state': 'Completed, Succeeded',
'progress': 100.0,
'file_path': file_path,
})
logger.info(
"%s download %s completed: %s",
source_name, download_id, file_path,
)
else:
self._engine.update_record(source_name, download_id, {
'state': 'Errored',
})
logger.error(
"%s download %s failed (impl returned None)",
source_name, download_id,
)
except Exception as exc:
# Defensive — anything in the worker_loop itself
# (semaphore, sleep) shouldn't blow up the thread, but
# if it does the record gets marked Errored so the
# download doesn't sit forever in 'Initializing'.
logger.exception(
"%s worker_loop crashed for download %s: %s",
source_name, download_id, exc,
)
self._engine.update_record(source_name, download_id, {
'state': 'Errored',
'error': f'worker crash: {exc}',
})

@ -0,0 +1,345 @@
"""Tests for `BackgroundDownloadWorker` (Phase C1).
These tests pin the worker's state-machine semantics, semaphore
serialization, rate-limit-delay behavior, and exception handling.
Future phases (C2C7) migrate each per-source client onto this
worker these tests stay green as the regression net.
"""
from __future__ import annotations
import threading
import time
from core.download_engine import DownloadEngine
# ---------------------------------------------------------------------------
# Dispatch — initial state + thread spawn
# ---------------------------------------------------------------------------
def test_dispatch_returns_uuid_download_id():
engine = DownloadEngine()
def impl(download_id, target_id, display_name):
return '/tmp/file.flac'
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='abc123',
display_name='Some Song',
original_filename='abc123||Some Song',
impl_callable=impl,
)
assert len(download_id) == 36 # UUID4
assert download_id.count('-') == 4
def test_dispatch_inserts_initial_record_with_canonical_state():
"""Pinning: initial record matches the legacy per-client shape so
consumers reading the state dict via API or context-key lookup
keep working unchanged after migration."""
engine = DownloadEngine()
captured = threading.Event()
def impl(download_id, target_id, display_name):
captured.wait(timeout=1.0) # block so we can read 'Initializing' / 'InProgress' state
return '/tmp/file.flac'
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='abc',
display_name='X',
original_filename='abc||X',
impl_callable=impl,
)
record = engine.get_record('youtube', download_id)
assert record is not None
assert record['id'] == download_id
assert record['filename'] == 'abc||X'
assert record['username'] == 'youtube'
assert record['state'] in ('Initializing', 'InProgress, Downloading')
assert record['progress'] == 0.0
assert record['file_path'] is None
captured.set() # release impl
def test_dispatch_merges_extra_record_fields():
"""Pinning: source-specific slots (video_id, track_id, etc.)
merge into the initial record so frontend + status APIs that
read those keys keep working."""
engine = DownloadEngine()
started = threading.Event()
release = threading.Event()
def impl(download_id, target_id, display_name):
started.set()
release.wait(timeout=1.0)
return '/tmp/x.flac'
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='vid123',
display_name='Title',
original_filename='vid123||Title',
impl_callable=impl,
extra_record_fields={
'video_id': 'vid123',
'url': 'https://youtube.com/watch?v=vid123',
'title': 'Title',
},
)
started.wait(timeout=1.0)
record = engine.get_record('youtube', download_id)
assert record['video_id'] == 'vid123'
assert record['url'] == 'https://youtube.com/watch?v=vid123'
assert record['title'] == 'Title'
release.set()
def test_dispatch_username_override_preserves_legacy_slot():
"""Pinning: Deezer's record stores `'deezer_dl'` (legacy) in the
username slot, not the canonical `'deezer'`. Worker accepts
override so frontend status indicators keep their key."""
engine = DownloadEngine()
release = threading.Event()
def impl(download_id, target_id, display_name):
release.wait(timeout=1.0)
return '/tmp/x.flac'
download_id = engine.worker.dispatch(
source_name='deezer',
target_id='999',
display_name='X',
original_filename='999||X',
impl_callable=impl,
username_override='deezer_dl',
)
record = engine.get_record('deezer', download_id)
assert record['username'] == 'deezer_dl'
release.set()
# ---------------------------------------------------------------------------
# Worker lifecycle — state transitions
# ---------------------------------------------------------------------------
def test_worker_marks_completed_on_successful_impl():
engine = DownloadEngine()
def impl(download_id, target_id, display_name):
return '/tmp/done.flac'
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='vid',
display_name='X',
original_filename='vid||X',
impl_callable=impl,
)
# Wait for thread to finish.
deadline = time.time() + 2.0
while time.time() < deadline:
record = engine.get_record('youtube', download_id)
if record and record['state'] == 'Completed, Succeeded':
break
time.sleep(0.01)
record = engine.get_record('youtube', download_id)
assert record['state'] == 'Completed, Succeeded'
assert record['progress'] == 100.0
assert record['file_path'] == '/tmp/done.flac'
def test_worker_marks_errored_when_impl_returns_none():
engine = DownloadEngine()
def impl(download_id, target_id, display_name):
return None # signaling failure
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='vid',
display_name='X',
original_filename='vid||X',
impl_callable=impl,
)
deadline = time.time() + 2.0
while time.time() < deadline:
record = engine.get_record('youtube', download_id)
if record and record['state'] == 'Errored':
break
time.sleep(0.01)
record = engine.get_record('youtube', download_id)
assert record['state'] == 'Errored'
# file_path stays None (default).
assert record['file_path'] is None
def test_worker_marks_errored_and_captures_message_when_impl_raises():
engine = DownloadEngine()
def impl(download_id, target_id, display_name):
raise RuntimeError("api blew up")
download_id = engine.worker.dispatch(
source_name='youtube',
target_id='vid',
display_name='X',
original_filename='vid||X',
impl_callable=impl,
)
deadline = time.time() + 2.0
while time.time() < deadline:
record = engine.get_record('youtube', download_id)
if record and record['state'] == 'Errored':
break
time.sleep(0.01)
record = engine.get_record('youtube', download_id)
assert record['state'] == 'Errored'
assert 'api blew up' in record.get('error', '')
# ---------------------------------------------------------------------------
# Per-source semaphore serialization
# ---------------------------------------------------------------------------
def test_semaphore_serializes_downloads_for_same_source():
"""Pinning: with concurrency=1 (default), two dispatches against
the same source run sequentially. The legacy per-client
semaphore did the same consumers depend on this for
rate-limit safety against APIs like YouTube."""
engine = DownloadEngine()
in_progress = threading.Event()
can_finish = threading.Event()
overlap_count = 0
overlap_lock = threading.Lock()
active_count = [0]
def impl(download_id, target_id, display_name):
nonlocal overlap_count
with overlap_lock:
active_count[0] += 1
if active_count[0] > 1:
overlap_count += 1
in_progress.set()
can_finish.wait(timeout=2.0)
with overlap_lock:
active_count[0] -= 1
return '/tmp/x.flac'
# Default concurrency=1 — two dispatches must serialize.
dl1 = engine.worker.dispatch(
source_name='youtube', target_id='a', display_name='A',
original_filename='a||A', impl_callable=impl,
)
in_progress.wait(timeout=1.0)
in_progress.clear()
dl2 = engine.worker.dispatch(
source_name='youtube', target_id='b', display_name='B',
original_filename='b||B', impl_callable=impl,
)
# Give second dispatch a chance to attempt running in parallel
# (it should be blocked on the semaphore).
time.sleep(0.1)
assert overlap_count == 0, "second dispatch should be blocked behind semaphore"
# Release first; second proceeds.
can_finish.set()
# Wait for both to finish.
deadline = time.time() + 3.0
while time.time() < deadline:
r1 = engine.get_record('youtube', dl1)
r2 = engine.get_record('youtube', dl2)
if r1 and r2 and r1['state'] == 'Completed, Succeeded' and r2['state'] == 'Completed, Succeeded':
break
time.sleep(0.01)
assert overlap_count == 0
def test_semaphore_concurrency_can_be_increased():
"""When `set_concurrency(source, N)` is called, N downloads can
run in parallel for that source. Used by sources that support
parallel transfers (none today, but contract supports it)."""
engine = DownloadEngine()
engine.worker.set_concurrency('parallel-source', 3)
in_flight = []
in_flight_lock = threading.Lock()
can_finish = threading.Event()
max_observed = [0]
def impl(download_id, target_id, display_name):
with in_flight_lock:
in_flight.append(download_id)
max_observed[0] = max(max_observed[0], len(in_flight))
can_finish.wait(timeout=2.0)
with in_flight_lock:
in_flight.remove(download_id)
return '/tmp/x.flac'
for i in range(3):
engine.worker.dispatch(
source_name='parallel-source',
target_id=str(i),
display_name=f'd{i}',
original_filename=f'{i}||d{i}',
impl_callable=impl,
)
# Give threads time to ramp up.
time.sleep(0.2)
can_finish.set()
# Wait for them to finish.
time.sleep(0.5)
assert max_observed[0] == 3
# ---------------------------------------------------------------------------
# Per-source rate-limit delay
# ---------------------------------------------------------------------------
def test_delay_enforces_minimum_gap_between_downloads():
"""Pinning: YouTube uses 3s delay today (legacy
`_download_delay`). Worker-driven delay must enforce the same
gap so YouTube doesn't 429."""
engine = DownloadEngine()
engine.worker.set_delay('youtube', 0.2) # 200ms — short for test speed
completion_times = []
def impl(download_id, target_id, display_name):
completion_times.append(time.time())
return '/tmp/x.flac'
# Two back-to-back dispatches.
engine.worker.dispatch(
source_name='youtube', target_id='a', display_name='A',
original_filename='a||A', impl_callable=impl,
)
engine.worker.dispatch(
source_name='youtube', target_id='b', display_name='B',
original_filename='b||B', impl_callable=impl,
)
# Wait for both to finish (semaphore serializes + delay).
deadline = time.time() + 3.0
while time.time() < deadline and len(completion_times) < 2:
time.sleep(0.01)
assert len(completion_times) == 2
gap = completion_times[1] - completion_times[0]
# Gap is at LEAST the configured delay.
assert gap >= 0.18, f"expected gap >= 0.2s, got {gap:.3f}"
Loading…
Cancel
Save