You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
SoulSync/core/download_engine/engine.py

275 lines
12 KiB

"""DownloadEngine — central owner of cross-source download state.
Phase B scope: skeleton only. The engine exposes a place for
plugins to register, a single ``active_downloads`` dict keyed by
``(source, download_id)``, and a ``state_lock`` that guards mutations
across the multi-threaded download worker pool.
Subsequent phases bolt more capability on top:
- ``dispatch_download(plugin, target_id)`` (Phase C — replaces every
client's ``_download_thread_worker`` boilerplate).
- ``search(query, source_chain)`` (Phase D — replaces every client's
retry ladder + quality filter).
- ``rate_limit.acquire(source)`` (Phase E — replaces every client's
semaphore + last-download-timestamp dance).
- ``search_with_fallback`` / ``download_with_fallback`` (Phase F —
unifies hybrid mode across search and download).
The engine is constructed by ``DownloadOrchestrator.__init__`` and
each plugin from the registry is registered with it. In Phase B
nothing in the existing code paths goes through the engine yet —
this commit is pure additive scaffolding so subsequent commits can
introduce engine-driven behavior one piece at a time without a
big-bang switchover.
"""
from __future__ import annotations
import threading
from typing import Any, Dict, Iterator, List, Optional, Tuple
from utils.logging_config import get_logger
# Module-level logger used by DownloadEngine for its warning/debug
# messages (duplicate registrations, replaced records, plugin failures).
logger = get_logger("download_engine")
# Type alias for a single per-download state dict. Today's clients each
# define their own slightly different shape (see the Phase A pinning
# tests); the engine stores them as opaque ``str -> Any`` dicts and the
# per-plugin accessor preserves the source-specific fields.
DownloadRecord = Dict[str, Any]
class DownloadEngine:
    """Central state for every active download across every source.

    State is keyed by ``(source_name, download_id)`` so the same
    UUID could hypothetically appear in two sources without
    collision (in practice each source generates its own UUID4
    so collisions are negligible — the source qualifier exists
    so the engine can tell which plugin owns a download without
    asking every plugin).

    Thread safety: every mutation of the record table and of the plugin
    registry goes through ``state_lock``. Read-only accessors
    (``get_record``, ``iter_records_for_source``, ``iter_all_records``,
    ``find_record``) take the lock briefly and return SHALLOW COPIES so
    the caller can work with them without holding the lock. Callers
    that need to mutate a record should use ``update_record``, which
    takes the lock and applies the patch atomically.

    The async cross-source methods iterate a snapshot of the plugin
    registry, so a plugin registered while one of them is suspended on
    an ``await`` cannot break the iteration.
    """

    def __init__(self) -> None:
        # RLock (reentrant) so a plugin's worker callback can re-enter the
        # engine while already holding the lock for its own update.
        self.state_lock = threading.RLock()
        # Composite key: (source_name, download_id) -> record dict.
        self._records: Dict[Tuple[str, str], DownloadRecord] = {}
        # Source name -> plugin instance. Holding the references here keeps
        # plugin lookup local to the engine instead of forcing every
        # caller to also touch the registry.
        self._plugins: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Plugin registration
    # ------------------------------------------------------------------

    def register_plugin(self, source_name: str, plugin: Any) -> None:
        """Register ``plugin`` under its canonical ``source_name``.

        Called once per source by the orchestrator after the registry's
        ``initialize`` builds the client instances. Re-registering a name
        logs a warning and replaces the previous plugin. ``None`` may be
        registered; such entries are skipped by the cross-source methods.
        """
        with self.state_lock:
            if source_name in self._plugins:
                logger.warning("Plugin %s already registered with engine — overwriting", source_name)
            self._plugins[source_name] = plugin

    def get_plugin(self, source_name: str) -> Optional[Any]:
        """Return the plugin registered as ``source_name``, or None."""
        return self._plugins.get(source_name)

    def registered_sources(self) -> List[str]:
        """Return the names of all registered sources (including None plugins)."""
        with self.state_lock:
            return list(self._plugins.keys())

    def _plugin_snapshot(self) -> List[Tuple[str, Any]]:
        """Return ``(source_name, plugin)`` pairs for every non-None plugin.

        Copied under ``state_lock`` so async callers can ``await`` on each
        plugin without iterating the live registry dict.
        """
        with self.state_lock:
            return [
                (name, plugin)
                for name, plugin in self._plugins.items()
                if plugin is not None
            ]

    # ------------------------------------------------------------------
    # Active-downloads state — Phase B core surface
    # ------------------------------------------------------------------

    def add_record(self, source_name: str, download_id: str, record: DownloadRecord) -> None:
        """Insert (or replace, with a warning) a download record.

        A shallow copy of ``record`` is stored, so later caller-side
        mutations of the passed dict do not leak into engine state.
        """
        with self.state_lock:
            key = (source_name, download_id)
            if key in self._records:
                logger.warning("Replacing existing download record for %s/%s", source_name, download_id)
            self._records[key] = dict(record)

    def update_record(self, source_name: str, download_id: str, patch: DownloadRecord) -> None:
        """Merge ``patch`` into an existing record atomically.

        No-op if the record was already removed (e.g. cancelled mid-update).
        """
        with self.state_lock:
            existing = self._records.get((source_name, download_id))
            if existing is None:
                return
            existing.update(patch)

    def remove_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Delete a record (cancellation cleanup).

        Returns the removed record, or None if it was not present.
        """
        with self.state_lock:
            return self._records.pop((source_name, download_id), None)

    def get_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Return a SHALLOW COPY of the record, or None if absent.

        Mutating the returned dict does not affect engine state — use
        ``update_record`` for that.
        """
        with self.state_lock:
            record = self._records.get((source_name, download_id))
            return dict(record) if record is not None else None

    def iter_records_for_source(self, source_name: str) -> Iterator[DownloadRecord]:
        """Return an iterator over SHALLOW COPIES of a source's records.

        The snapshot is taken eagerly, under the lock, when this method is
        called; the returned iterator never touches the lock, so callers
        can spend arbitrary time on each record.
        """
        with self.state_lock:
            snapshot = [
                dict(record)
                for (source, _), record in self._records.items()
                if source == source_name
            ]
        return iter(snapshot)

    def iter_all_records(self) -> Iterator[Tuple[str, DownloadRecord]]:
        """Return an iterator of ``(source_name, record_copy)`` for every record.

        Like ``iter_records_for_source``, the snapshot is taken eagerly at
        call time. Used by Phase B3's unified ``get_all_downloads`` query.
        """
        with self.state_lock:
            snapshot = [
                (source, dict(record))
                for (source, _), record in self._records.items()
            ]
        return iter(snapshot)

    def find_record(self, download_id: str) -> Optional[Tuple[str, DownloadRecord]]:
        """Look up a record by ``download_id`` alone (no source hint).

        Used by ``cancel_download`` / ``get_download_status`` API endpoints
        that don't pass the source name. Returns ``(source_name,
        record_copy)`` for the first match, or None.

        O(N) over total downloads — fine for the tens-to-hundreds of
        in-flight transfers SoulSync sees; would need an index if
        downloads scaled to thousands.
        """
        with self.state_lock:
            for (source, dl_id), record in self._records.items():
                if dl_id == download_id:
                    return source, dict(record)
        return None

    # ------------------------------------------------------------------
    # Cross-source query dispatch — Phase B2 surface
    # ------------------------------------------------------------------
    #
    # These methods iterate the registered plugins and call their
    # existing ``get_all_downloads`` / ``get_download_status`` /
    # ``cancel_download`` / ``clear_all_completed_downloads`` methods, so
    # the orchestrator can become a thin pass-through (Phase B3). Phase
    # C/D will replace plugin iteration with direct engine-state queries.
    #
    # All methods are async to match the per-plugin contract. They are
    # best-effort: a failing plugin is logged and skipped, never raised.

    async def get_all_downloads(self):
        """Aggregate every registered plugin's active downloads.

        Returns a flat list of whatever each plugin's ``get_all_downloads``
        yields (DownloadStatus objects per the plugin contract). Plugins
        that raise are logged at debug level and skipped; a plugin that
        returns None contributes nothing.
        """
        all_downloads = []
        for source_name, plugin in self._plugin_snapshot():
            try:
                all_downloads.extend(await plugin.get_all_downloads() or [])
            except Exception as exc:
                # Polled frequently, so debug level avoids log spam from a
                # persistently failing source.
                logger.debug("%s get_all_downloads failed: %s", source_name, exc)
        return all_downloads

    async def get_download_status(self, download_id: str):
        """Find ``download_id`` across every plugin.

        Returns the first truthy status a plugin reports, or None if no
        plugin owns it. Plugins that raise are logged and skipped.
        """
        for source_name, plugin in self._plugin_snapshot():
            try:
                status = await plugin.get_download_status(download_id)
                if status:
                    return status
            except Exception as exc:
                logger.debug("%s get_download_status(%s) failed: %s", source_name, download_id, exc)
        return None

    async def _cancel_direct(self, source_name: str, plugin: Any, download_id: str,
                             source_hint: Optional[str], remove: bool):
        """Cancel through one specific plugin; errors are logged and yield False."""
        try:
            return await plugin.cancel_download(download_id, source_hint, remove)
        except Exception as exc:
            logger.warning("%s cancel_download(%s) failed: %s", source_name, download_id, exc)
            return False

    async def cancel_download(self, download_id: str,
                              source_hint: Optional[str] = None,
                              remove: bool = False) -> bool:
        """Cancel a download, routing by ``source_hint`` when possible.

        ``source_hint`` is the source name (or a legacy username string
        like ``'deezer_dl'``). Routing:

        * hint names a registered plugin other than ``'soulseek'`` → only
          that plugin is asked;
        * otherwise, if a ``'soulseek'`` plugin is registered → only
          soulseek is asked (Soulseek hints are real peer usernames, which
          are not in the plugin registry);
        * otherwise (including no hint at all) → every registered plugin is
          asked in turn until one accepts the cancel.

        Returns the routed plugin's result (False if it raised), or for the
        fan-out, True if any plugin accepted and False otherwise.
        """
        if source_hint:
            # Streaming source names ARE the username.
            target_plugin = self.get_plugin(source_hint)
            if target_plugin is not None and source_hint != 'soulseek':
                return await self._cancel_direct(source_hint, target_plugin, download_id, source_hint, remove)
            soulseek = self.get_plugin('soulseek')
            if soulseek is not None:
                return await self._cancel_direct('soulseek', soulseek, download_id, source_hint, remove)

        # Fan-out: reached with no hint, or when a hint was given but
        # neither the hinted plugin nor soulseek is registered.
        for source_name, plugin in self._plugin_snapshot():
            try:
                if await plugin.cancel_download(download_id, source_hint, remove):
                    return True
            except Exception as exc:
                logger.debug("%s cancel_download(%s) failed: %s", source_name, download_id, exc)
        return False

    async def clear_all_completed_downloads(self) -> bool:
        """Best-effort cleanup of every plugin's completed-downloads list.

        Skips plugins that report not-configured (saves API calls and log
        noise). Any plugin error — including from ``is_configured`` — is
        logged and counted as a failure without stopping the others.
        Returns True when every attempted cleanup succeeded (or none was
        attempted).
        """
        results = []
        for source_name, plugin in self._plugin_snapshot():
            try:
                if hasattr(plugin, 'is_configured') and not plugin.is_configured():
                    logger.debug("Skipping %s clear_all_completed_downloads (not configured)", source_name)
                    continue
                results.append(await plugin.clear_all_completed_downloads())
            except Exception as exc:
                logger.warning("%s clear_all_completed_downloads failed: %s", source_name, exc)
                results.append(False)
        return all(results) if results else True