# [web-view residue, not part of the module] You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# SoulSync/core/download_engine/engine.py
#
# 458 lines
# 21 KiB
#
"""DownloadEngine — central owner of cross-source download state.
Phase B scope: skeleton only. The engine exposes a place for
plugins to register, a per-source record store (a nested dict,
``source → {download_id → record}``), and per-source RLocks that guard
mutations without serializing workers across different sources.
Subsequent phases bolt more capability on top:
- ``dispatch_download(plugin, target_id)`` (Phase C — replaces every
client's ``_download_thread_worker`` boilerplate).
- ``search(query, source_chain)`` (Phase D — replaces every client's
retry ladder + quality filter).
- ``rate_limit.acquire(source)`` (Phase E — replaces every client's
semaphore + last-download-timestamp dance).
- ``search_with_fallback`` / ``download_with_fallback`` (Phase F —
unifies hybrid mode across search and download).
The engine is constructed by ``DownloadOrchestrator.__init__`` and
each plugin from the registry is registered with it. In Phase B
nothing in the existing code paths goes through the engine yet —
this commit is pure additive scaffolding so subsequent commits can
introduce engine-driven behavior one piece at a time without a
big-bang switchover.
"""
from __future__ import annotations
import threading
from typing import Any, Dict, Iterator, List, Optional, Tuple
from utils.logging_config import get_logger
# Module-level logger, obtained from the project's logging helper under
# the name "download_engine"; every log call in this module goes through it.
logger = get_logger("download_engine")
# Type alias for the per-download state dict. Today's clients each
# define their own slightly-different shape (see Phase A pinning
# tests); the engine stores them as opaque dicts and the per-plugin
# accessor preserves the source-specific fields. The engine itself only
# ever reads the ``'state'`` key (in ``update_record_unless_state``);
# every other key is passed through untouched.
DownloadRecord = Dict[str, Any]
class DownloadEngine:
    """Central state for every active download across every source.

    Records live in a nested mapping ``source_name → {download_id →
    record}``, so every lookup is qualified by ``(source_name,
    download_id)``. The same UUID could hypothetically appear in two
    sources without collision (in practice each source generates its
    own UUID4 so collisions are negligible — the source qualifier
    exists so the engine can answer "which plugin owns this download"
    without iterating every plugin).

    Thread safety: per-source lock sharding. Each source gets its own
    RLock — progress callbacks on Deezer don't block Tidal's worker
    and vice versa, matching the pre-refactor behavior where each
    client owned its own download lock. Read-only accessors
    (``get_record``, ``iter_records_for_source``) take the source's
    lock briefly and return a SHALLOW COPY so the caller can iterate
    without holding the lock. Callers that need to mutate a record
    should use ``update_record`` which takes the lock and applies the
    patch atomically.

    The cross-source coroutines (``get_all_downloads``,
    ``get_download_status``, ``cancel_download``,
    ``clear_all_completed_downloads``) iterate a snapshot of the plugin
    table, so a plugin registered while one of them is suspended on an
    ``await`` cannot break the loop.
    """

    def __init__(self) -> None:
        # Nested dict: source_name → {download_id → record}. Replaces
        # the original single-dict composite-key layout so
        # ``iter_records_for_source`` is O(source_records) instead of
        # O(total_records).
        self._records: Dict[str, Dict[str, DownloadRecord]] = {}

        # Per-source RLocks. Each source gets its own so progress
        # updates on one source never block writes on another. RLock
        # so a plugin's worker callback can re-enter while holding the
        # lock for its own update. Lazily created via ``_source_lock``;
        # the meta-lock guards creation against the create-race window
        # where two threads could both miss + both create.
        self._source_locks: Dict[str, threading.RLock] = {}
        self._source_locks_lock = threading.Lock()

        # Plugins that have registered with the engine. Source name
        # → plugin instance.
        self._plugins: Dict[str, Any] = {}

        # Alias → canonical-name map. Lets engine resolve legacy
        # source-name strings (e.g. ``'deezer_dl'`` for Deezer) to
        # the canonical key in ``_plugins``. Cin's review caught
        # that engine.cancel_download(source_hint='deezer_dl')
        # silently fell through to Soulseek because alias resolution
        # only existed at the registry, not on the engine.
        self._aliases: Dict[str, str] = {}

        # Background download worker — lives on the engine because
        # it owns the cross-source state the worker mutates. Lazy
        # import keeps the engine module standalone.
        from core.download_engine.worker import BackgroundDownloadWorker
        self.worker = BackgroundDownloadWorker(self)

    # ------------------------------------------------------------------
    # Plugin registration
    # ------------------------------------------------------------------

    def register_plugin(self, source_name: str, plugin: Any,
                        aliases: Tuple[str, ...] = ()) -> None:
        """Register a plugin under its canonical source name.

        Called once per source by the orchestrator after the registry's
        ``initialize`` builds the client instances. Registering a name
        that is already present logs a warning and overwrites it.

        ``aliases`` is the list of legacy source-name strings that
        should resolve to this plugin (e.g. ``'deezer_dl'`` for
        Deezer). Without alias resolution the engine couldn't route
        cancel/lookup calls that came in with the legacy name.

        If the plugin exposes ``set_engine(engine)``, the engine
        passes a self-reference so the plugin can dispatch into
        ``engine.worker`` / read state / etc. Plugins that haven't
        been migrated to the engine yet simply don't define
        ``set_engine`` — they keep their pre-engine behavior
        unchanged. A ``set_engine`` that raises is logged and
        otherwise ignored; the plugin stays registered.

        Also reads the plugin's declared ``RateLimitPolicy`` (via
        the ``rate_limit_policy()`` method or ``RATE_LIMIT_POLICY``
        class attribute) and applies it to the worker. Plugins that
        don't declare a policy get the conservative default
        (concurrency=1, delay=0).
        """
        if source_name in self._plugins:
            logger.warning("Plugin %s already registered with engine — overwriting", source_name)
        self._plugins[source_name] = plugin
        for alias in aliases:
            self._aliases[alias] = source_name

        # Apply the plugin's rate-limit policy BEFORE set_engine so
        # set_engine callbacks can override per-source if they need
        # config-driven values (e.g. YouTube's user-tunable delay).
        from core.download_engine.rate_limit import resolve_policy
        policy = resolve_policy(plugin)
        self.worker.set_concurrency(source_name, policy.download_concurrency)
        self.worker.set_delay(source_name, policy.download_delay_seconds)

        set_engine = getattr(plugin, 'set_engine', None)
        if callable(set_engine):
            try:
                set_engine(self)
            except Exception as exc:
                logger.warning(
                    "Plugin %s set_engine callback failed: %s", source_name, exc,
                )

    def get_plugin(self, source_name: str) -> Optional[Any]:
        """Return the plugin instance for the given source name.

        Resolves through aliases — e.g. ``get_plugin('deezer_dl')``
        returns the same instance as ``get_plugin('deezer')``. Returns
        None when the name is neither a canonical name nor an alias.
        """
        if source_name in self._plugins:
            return self._plugins[source_name]
        canonical = self._aliases.get(source_name)
        if canonical:
            return self._plugins.get(canonical)
        return None

    def _resolve_canonical(self, source_name: str) -> Optional[str]:
        """Return the canonical source name for an input that may be
        an alias. Returns None if the input matches neither a
        canonical name nor an alias."""
        if source_name in self._plugins:
            return source_name
        return self._aliases.get(source_name)

    def registered_sources(self) -> List[str]:
        """Return the canonical names of every registered plugin, in
        registration order (aliases are not included)."""
        return list(self._plugins.keys())

    @staticmethod
    def _plugin_is_configured(plugin: Any) -> bool:
        """Return False only when the plugin exposes ``is_configured``
        and it reports a falsy value; plugins without the method are
        treated as configured."""
        if not hasattr(plugin, 'is_configured'):
            return True
        return bool(plugin.is_configured())

    def _source_lock(self, source_name: str) -> threading.RLock:
        """Return the per-source RLock, lazy-creating it on first use.

        The meta-lock around the cache lookup closes the create-race
        window where two threads both miss + both create a fresh lock.
        """
        with self._source_locks_lock:
            lock = self._source_locks.get(source_name)
            if lock is None:
                lock = threading.RLock()
                self._source_locks[source_name] = lock
            return lock

    # ------------------------------------------------------------------
    # Active-downloads state — Phase B core surface
    # ------------------------------------------------------------------

    def add_record(self, source_name: str, download_id: str, record: DownloadRecord) -> None:
        """Insert a fresh download record. Used by clients (today
        directly via their own dicts; Phase B2 routes them through
        here).

        The engine stores a shallow copy of ``record``, so later
        mutations of the caller's dict don't leak into engine state.
        An existing record under the same id is replaced (with a
        warning).
        """
        with self._source_lock(source_name):
            source_bucket = self._records.setdefault(source_name, {})
            if download_id in source_bucket:
                logger.warning("Replacing existing download record for %s/%s", source_name, download_id)
            source_bucket[download_id] = dict(record)

    def update_record(self, source_name: str, download_id: str, patch: DownloadRecord) -> None:
        """Apply a partial patch to an existing record. No-op if the
        record was already removed (e.g. cancelled mid-update)."""
        with self._source_lock(source_name):
            existing = self._records.get(source_name, {}).get(download_id)
            if existing is None:
                return
            existing.update(patch)

    def update_record_unless_state(self, source_name: str, download_id: str,
                                   patch: DownloadRecord,
                                   skip_if_state_in: Tuple[str, ...] = ()) -> bool:
        """Atomically check the record's state and apply ``patch`` only
        if the current state is NOT in ``skip_if_state_in``. Returns
        True if the patch was applied, False if it was skipped (or
        the record didn't exist).

        Used by the background download worker's ``_mark_terminal``
        to avoid the read-then-write race Cin flagged: a cancel
        landing between the snapshot and update could be overwritten
        back to Errored / Completed. Holding the source's lock across
        the check + write closes the window.
        """
        with self._source_lock(source_name):
            existing = self._records.get(source_name, {}).get(download_id)
            if existing is None:
                return False
            if existing.get('state') in skip_if_state_in:
                return False
            existing.update(patch)
            return True

    def remove_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Delete a record (cancellation cleanup). Returns the
        removed record or None if not found."""
        with self._source_lock(source_name):
            source_bucket = self._records.get(source_name)
            if not source_bucket:
                return None
            removed = source_bucket.pop(download_id, None)
            # Drop the empty source bucket so iteration / membership
            # checks don't see a stale source key.
            if not source_bucket:
                self._records.pop(source_name, None)
            return removed

    def get_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Return a SHALLOW COPY of the record, or None if the source
        has no record under that id. Caller mutations don't affect
        engine state — use ``update_record`` for that."""
        with self._source_lock(source_name):
            record = self._records.get(source_name, {}).get(download_id)
            return dict(record) if record is not None else None

    def iter_records_for_source(self, source_name: str) -> Iterator[DownloadRecord]:
        """Yield SHALLOW COPIES of every record owned by a source.

        Holds the source's lock briefly to snapshot, then yields
        outside the lock so callers can spend arbitrary time on each
        record. Being a generator, the snapshot is taken when the
        first item is requested, not when the method is called.

        With the nested-dict layout this is O(source_records) — only
        touches the bucket for the requested source, not every record
        across every source.
        """
        with self._source_lock(source_name):
            source_bucket = self._records.get(source_name, {})
            snapshot = [dict(record) for record in source_bucket.values()]
        yield from snapshot

    # ------------------------------------------------------------------
    # Cross-source query dispatch — Phase B2 surface
    # ------------------------------------------------------------------
    #
    # The orchestrator historically iterated every plugin in its own
    # ``get_all_downloads`` / ``get_download_status`` / ``cancel_download``
    # methods (with hand-maintained client lists, before the registry
    # came along). That iteration logic moves into the engine here so
    # the orchestrator becomes a thin pass-through (Phase B3).
    #
    # In Phase B these methods iterate the registered plugins and call
    # their existing ``get_all_downloads`` / ``cancel_download``
    # methods — same behavior as today, just in a new home. Phase C/D
    # will replace plugin-iteration with direct engine-state queries
    # once the thread worker is also lifted.
    #
    # All methods are async to match the per-plugin contract. Each loop
    # walks ``list(self._plugins.items())`` rather than the live view:
    # the loop body awaits, and a registration landing during that
    # await would otherwise make the next iteration step raise
    # ``RuntimeError: dictionary changed size during iteration``.

    async def get_all_downloads(self, exclude: Tuple[str, ...] = ()):
        """Aggregated view across every registered plugin's active
        downloads. Per-plugin exceptions are swallowed (one source
        failing shouldn't take down cross-source aggregation) but
        logged at debug level — same defensive shape the legacy
        orchestrator had.

        ``exclude`` skips named sources entirely. The download monitor
        passes ``('soulseek',)`` so it doesn't double-fetch slskd
        transfers (it already pulled them via the slskd transfers
        endpoint earlier in the same loop).
        """
        all_downloads = []
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None or source_name in exclude:
                continue
            try:
                all_downloads.extend(await plugin.get_all_downloads())
            except Exception as exc:
                logger.debug("%s get_all_downloads failed: %s", source_name, exc)
        return all_downloads

    async def get_download_status(self, download_id: str):
        """Find a download_id across every plugin. Returns the first
        truthy plugin response, or None if no plugin owns it. A plugin
        that raises is logged at debug level and skipped."""
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            try:
                status = await plugin.get_download_status(download_id)
                if status:
                    return status
            except Exception as exc:
                logger.debug("%s get_download_status failed: %s", source_name, exc)
        return None

    async def cancel_download(self, download_id: str,
                              source_hint: Optional[str] = None,
                              remove: bool = False) -> bool:
        """Cancel a download. ``source_hint`` is the source name (or
        legacy alias like ``'deezer_dl'``, or a real Soulseek peer
        username) — when provided, routes directly to that plugin.
        When omitted, every plugin is asked in turn until one accepts.

        Cin's review caught a bug here: legacy alias strings like
        ``'deezer_dl'`` weren't resolved to the canonical ``'deezer'``
        plugin name, so the cancel silently fell through to Soulseek.
        Resolution now goes through ``_resolve_canonical`` first.

        Returns the plugin's own result, or False when the targeted
        plugin raised or no plugin accepted the cancel.
        """
        # Direct routing when the caller knows the source.
        if source_hint:
            canonical = self._resolve_canonical(source_hint)
            # Streaming source names (or aliases) resolve to a
            # registered plugin. Anything else (real Soulseek peer
            # name not in our registry) routes to Soulseek.
            if canonical and canonical != 'soulseek':
                target_plugin = self._plugins.get(canonical)
                if target_plugin is not None:
                    try:
                        return await target_plugin.cancel_download(
                            download_id, source_hint, remove,
                        )
                    except Exception as exc:
                        logger.debug("%s cancel_download failed: %s", canonical, exc)
                        return False
            soulseek = self._plugins.get('soulseek')
            if soulseek is not None:
                try:
                    return await soulseek.cancel_download(download_id, source_hint, remove)
                except Exception as exc:
                    logger.debug("soulseek cancel_download failed: %s", exc)
            return False

        # No hint → ask every plugin until one cancels successfully.
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            try:
                if await plugin.cancel_download(download_id, source_hint, remove):
                    return True
            except Exception as exc:
                logger.debug("%s cancel_download failed: %s", source_name, exc)
        return False

    async def clear_all_completed_downloads(self) -> bool:
        """Best-effort cleanup of every plugin's completed-downloads
        list. Skips plugins that report not-configured (saves API
        calls + log noise).

        Returns True when every attempted plugin returned a truthy
        result (or when nothing was attempted); a plugin that raises
        counts as a failure.
        """
        results = []
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            if not self._plugin_is_configured(plugin):
                logger.debug("Skipping %s clear_all_completed_downloads (not configured)", source_name)
                continue
            try:
                results.append(await plugin.clear_all_completed_downloads())
            except Exception as exc:
                logger.warning("%s clear_all_completed_downloads failed: %s", source_name, exc)
                results.append(False)
        return all(results) if results else True

    # ------------------------------------------------------------------
    # Hybrid fallback — Phase F surface
    # ------------------------------------------------------------------

    async def search_with_fallback(self, query: str, source_chain,
                                   timeout=None, progress_callback=None):
        """Try each source in ``source_chain`` until one returns
        tracks. Skips unconfigured / unregistered sources, swallows
        per-source search exceptions. Returns the first non-empty
        (tracks, albums) tuple, or ``([], [])`` when every source
        in the chain is exhausted.

        Replaces orchestrator's hand-rolled hybrid search loop. The
        chain is ordered (most-preferred first). Entries may be
        canonical source names or registered aliases — both resolve
        through ``get_plugin``.
        """
        # Materialize once: the chain is walked here and joined again
        # for the "found nothing" log line, which a one-shot iterable
        # could not survive.
        chain = list(source_chain)
        for priority, source_name in enumerate(chain, start=1):
            plugin = self.get_plugin(source_name)
            if plugin is None:
                logger.info("Skipping %s (not available)", source_name)
                continue
            if not self._plugin_is_configured(plugin):
                logger.info("Skipping %s (not configured)", source_name)
                continue
            try:
                logger.info("Trying %s (priority %d): %s", source_name, priority, query)
                tracks, albums = await plugin.search(query, timeout, progress_callback)
                if tracks:
                    logger.info("%s found %d tracks", source_name, len(tracks))
                    return (tracks, albums)
            except Exception as exc:
                logger.warning("%s search failed: %s", source_name, exc)
        logger.warning(
            "Hybrid search: all sources (%s) found nothing for: %s",
            ', '.join(map(str, chain)), query,
        )
        return ([], [])

    async def download_with_fallback(self, username: str, filename: str,
                                     file_size: int, source_chain) -> Optional[str]:
        """Try each source in ``source_chain`` until one accepts the
        download (returns a non-None download_id). Fixes the legacy
        bug where hybrid mode silently routed to a single source via
        the username hint with no retry on failure.

        ``username`` is treated as a hint when it matches a source
        name in the chain — that source is tried FIRST regardless of
        chain order. Anything else (e.g. a real Soulseek peer name)
        routes through the chain in declared order.

        Chain entries may be canonical source names or registered
        aliases — both resolve through ``get_plugin``. Returns None
        when every source declined, raised, or was skipped.
        """
        # Promote a matching source-name hint to the head of the chain.
        ordered_chain = list(source_chain)
        if username and username in ordered_chain:
            ordered_chain.remove(username)
            ordered_chain.insert(0, username)

        for source_name in ordered_chain:
            plugin = self.get_plugin(source_name)
            if plugin is None:
                continue
            if not self._plugin_is_configured(plugin):
                continue
            try:
                download_id = await plugin.download(username, filename, file_size)
                if download_id is not None:
                    return download_id
                logger.info("%s declined download — trying next in chain", source_name)
            except Exception as exc:
                logger.warning("%s download raised — trying next in chain: %s", source_name, exc)
        logger.warning(
            "Hybrid download: every source in chain (%s) refused %r",
            ', '.join(map(str, ordered_chain)), filename,
        )
        return None