mirror of https://github.com/Nezreka/SoulSync.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
458 lines
21 KiB
458 lines
21 KiB
"""DownloadEngine — central owner of cross-source download state.
|
|
|
|
Phase B scope: skeleton only. The engine exposes a place for
|
|
plugins to register, a single ``active_downloads`` dict keyed by
|
|
``(source, download_id)``, and per-source RLocks that guard mutations
|
|
without serializing workers across different sources.
|
|
|
|
Subsequent phases bolt more capability on top:
|
|
- ``dispatch_download(plugin, target_id)`` (Phase C — replaces every
|
|
client's ``_download_thread_worker`` boilerplate).
|
|
- ``search(query, source_chain)`` (Phase D — replaces every client's
|
|
retry ladder + quality filter).
|
|
- ``rate_limit.acquire(source)`` (Phase E — replaces every client's
|
|
semaphore + last-download-timestamp dance).
|
|
- ``search_with_fallback`` / ``download_with_fallback`` (Phase F —
|
|
unifies hybrid mode across search and download).
|
|
|
|
The engine is constructed by ``DownloadOrchestrator.__init__`` and
|
|
each plugin from the registry is registered with it. In Phase B
|
|
nothing in the existing code paths goes through the engine yet —
|
|
this commit is pure additive scaffolding so subsequent commits can
|
|
introduce engine-driven behavior one piece at a time without a
|
|
big-bang switchover.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
|
|
|
from utils.logging_config import get_logger
|
|
|
|
logger = get_logger("download_engine")
|
|
|
|
|
|
# Type alias for the per-download state dict. Today's clients each
# define their own slightly-different shape (see Phase A pinning
# tests); the engine stores them as opaque dicts and the per-plugin
# accessor preserves the source-specific fields.
|
|
DownloadRecord = Dict[str, Any]
|
|
|
|
|
|
class DownloadEngine:
    """Central state for every active download across every source.

    State is keyed by ``(source_name, download_id)`` so the same
    UUID could hypothetically appear in two sources without
    collision (in practice each source generates its own UUID4
    so collisions are negligible — the source qualifier exists
    so the engine can answer "which plugin owns this download" in
    O(1) without iterating every plugin).

    Thread safety: per-source lock sharding. Each source gets its own
    RLock — progress callbacks on Deezer don't block Tidal's worker
    and vice versa, matching the pre-refactor behavior where each
    client owned its own download lock. Read-only accessors
    (``get_record``, ``iter_records_for_source``) take the source's
    lock briefly and return a SHALLOW COPY so the caller can iterate
    without holding the lock. Callers that need to mutate a record
    should use ``update_record`` which takes the lock and applies the
    patch atomically.

    Source names: plugin-facing entry points (``get_plugin``,
    ``cancel_download``, ``get_all_downloads``' ``exclude`` and the
    ``*_with_fallback`` chains) accept either a canonical source name
    or a registered legacy alias. The record store is keyed by the
    exact ``source_name`` string the caller passes — no alias
    resolution happens there.
    """

    def __init__(self) -> None:
        # Nested dict: source_name → {download_id → record}. Replaces
        # the original single-dict composite-key layout so
        # ``iter_records_for_source`` is O(source_records) instead of
        # O(total_records).
        self._records: Dict[str, Dict[str, DownloadRecord]] = {}
        # Per-source RLocks. Each source gets its own so progress
        # updates on one source never block writes on another. RLock
        # so a plugin's worker callback can re-enter while holding the
        # lock for its own update. Lazily created via ``_source_lock``;
        # the meta-lock guards creation against the create-race window
        # where two threads could both miss + both create.
        self._source_locks: Dict[str, threading.RLock] = {}
        self._source_locks_lock = threading.Lock()
        # Plugins that have registered with the engine. Source name
        # → plugin instance.
        self._plugins: Dict[str, Any] = {}
        # Alias → canonical-name map. Lets engine resolve legacy
        # source-name strings (e.g. ``'deezer_dl'`` for Deezer) to
        # the canonical key in ``_plugins``. Cin's review caught
        # that engine.cancel_download(source_hint='deezer_dl')
        # silently fell through to Soulseek because alias resolution
        # only existed at the registry, not on the engine.
        self._aliases: Dict[str, str] = {}
        # Background download worker — lives on the engine because
        # it owns the cross-source state the worker mutates. Lazy
        # import keeps the engine module standalone.
        from core.download_engine.worker import BackgroundDownloadWorker
        self.worker = BackgroundDownloadWorker(self)

    # ------------------------------------------------------------------
    # Plugin registration
    # ------------------------------------------------------------------

    def register_plugin(self, source_name: str, plugin: Any,
                        aliases: Tuple[str, ...] = ()) -> None:
        """Register a plugin under its canonical source name. Called
        once per source by the orchestrator after the registry's
        ``initialize`` builds the client instances.

        ``aliases`` is the list of legacy source-name strings that
        should resolve to this plugin (e.g. ``'deezer_dl'`` for
        Deezer). Without alias resolution the engine couldn't route
        cancel/lookup calls that came in with the legacy name.

        If the plugin exposes ``set_engine(engine)``, the engine
        passes a self-reference so the plugin can dispatch into
        ``engine.worker`` / read state / etc. Plugins that haven't
        been migrated to the engine yet simply don't define
        ``set_engine`` — they keep their pre-engine behavior
        unchanged. A ``set_engine`` callback that raises is logged
        and otherwise ignored; the plugin stays registered.

        Also reads the plugin's declared ``RateLimitPolicy`` (via
        the ``rate_limit_policy()`` method or ``RATE_LIMIT_POLICY``
        class attribute) and applies it to the worker. Plugins that
        don't declare a policy get the conservative default
        (concurrency=1, delay=0).
        """
        if source_name in self._plugins:
            logger.warning("Plugin %s already registered with engine — overwriting", source_name)
        self._plugins[source_name] = plugin
        for alias in aliases:
            self._aliases[alias] = source_name

        # Apply the plugin's rate-limit policy BEFORE set_engine so
        # set_engine callbacks can override per-source if they need
        # config-driven values (e.g. YouTube's user-tunable delay).
        from core.download_engine.rate_limit import resolve_policy
        policy = resolve_policy(plugin)
        self.worker.set_concurrency(source_name, policy.download_concurrency)
        self.worker.set_delay(source_name, policy.download_delay_seconds)

        set_engine = getattr(plugin, 'set_engine', None)
        if callable(set_engine):
            try:
                set_engine(self)
            except Exception as exc:
                logger.warning(
                    "Plugin %s set_engine callback failed: %s", source_name, exc,
                )

    def get_plugin(self, source_name: str) -> Optional[Any]:
        """Return the plugin instance for the given source name.
        Resolves through aliases — e.g. ``get_plugin('deezer_dl')``
        returns the same instance as ``get_plugin('deezer')``.
        Returns None when the name is neither registered nor an alias."""
        if source_name in self._plugins:
            return self._plugins[source_name]
        canonical = self._aliases.get(source_name)
        if canonical:
            return self._plugins.get(canonical)
        return None

    def _resolve_canonical(self, source_name: str) -> Optional[str]:
        """Return the canonical source name for an input that may be
        an alias. Returns None if the input matches neither a
        canonical name nor an alias."""
        if source_name in self._plugins:
            return source_name
        return self._aliases.get(source_name)

    def _resolve_chain(self, source_chain) -> List[Tuple[str, Optional[Any]]]:
        """Resolve a fallback chain to ``(requested_name, plugin)`` pairs.

        Each entry goes through alias resolution, so a chain that
        lists ``'deezer_dl'`` reaches the Deezer plugin instead of
        being skipped as unregistered. ``plugin`` is None for names
        that resolve to nothing — callers decide how to report that.

        An entry that names a source already listed under a
        DIFFERENT spelling (alias + canonical in the same chain) is
        dropped so one plugin isn't tried twice for the same request.
        Literal repeats of the same spelling are kept as given.

        The chain is consumed exactly once, so one-shot iterables
        are safe to pass.
        """
        resolved: List[Tuple[str, Optional[Any]]] = []
        first_spelling: Dict[str, str] = {}
        for name in source_chain:
            canonical = self._resolve_canonical(name) or name
            if first_spelling.setdefault(canonical, name) != name:
                continue
            resolved.append((name, self.get_plugin(name)))
        return resolved

    def registered_sources(self) -> List[str]:
        """Return the canonical names of every registered plugin, in
        registration order. Aliases are not included."""
        return list(self._plugins.keys())

    def _source_lock(self, source_name: str) -> threading.RLock:
        """Return the per-source RLock, lazy-creating it on first use.
        The meta-lock around the cache lookup closes the create-race
        window where two threads both miss + both create a fresh lock.
        """
        with self._source_locks_lock:
            lock = self._source_locks.get(source_name)
            if lock is None:
                lock = threading.RLock()
                self._source_locks[source_name] = lock
            return lock

    # ------------------------------------------------------------------
    # Active-downloads state — Phase B core surface
    # ------------------------------------------------------------------

    def add_record(self, source_name: str, download_id: str, record: DownloadRecord) -> None:
        """Insert a fresh download record. Used by clients (today
        directly via their own dicts; Phase B2 routes them through
        here).

        The engine stores a shallow copy of ``record``, so later
        mutation of the caller's dict doesn't leak into engine state.
        An existing record with the same id is replaced (and logged).
        """
        with self._source_lock(source_name):
            source_bucket = self._records.setdefault(source_name, {})
            if download_id in source_bucket:
                logger.warning("Replacing existing download record for %s/%s", source_name, download_id)
            source_bucket[download_id] = dict(record)

    def update_record(self, source_name: str, download_id: str, patch: DownloadRecord) -> None:
        """Apply a partial patch to an existing record. No-op if the
        record was already removed (e.g. cancelled mid-update)."""
        with self._source_lock(source_name):
            existing = self._records.get(source_name, {}).get(download_id)
            if existing is None:
                return
            existing.update(patch)

    def update_record_unless_state(self, source_name: str, download_id: str,
                                   patch: DownloadRecord,
                                   skip_if_state_in: Tuple[str, ...] = ()) -> bool:
        """Atomically check the record's state and apply ``patch`` only
        if the current state is NOT in ``skip_if_state_in``. Returns
        True if the patch was applied, False if it was skipped (or
        the record didn't exist).

        Used by the background download worker's ``_mark_terminal``
        to avoid the read-then-write race Cin flagged: a cancel
        landing between the snapshot and update could be overwritten
        back to Errored / Completed. Holding the source's lock across
        the check + write closes the window.
        """
        with self._source_lock(source_name):
            existing = self._records.get(source_name, {}).get(download_id)
            if existing is None:
                return False
            if existing.get('state') in skip_if_state_in:
                return False
            existing.update(patch)
            return True

    def remove_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Delete a record (cancellation cleanup). Returns the
        removed record or None if not found."""
        with self._source_lock(source_name):
            source_bucket = self._records.get(source_name)
            if not source_bucket:
                return None
            removed = source_bucket.pop(download_id, None)
            # Drop the empty source bucket so iteration / membership
            # checks don't see a stale source key.
            if not source_bucket:
                self._records.pop(source_name, None)
            return removed

    def get_record(self, source_name: str, download_id: str) -> Optional[DownloadRecord]:
        """Return a SHALLOW COPY of the record, or None if it doesn't
        exist. Caller mutations don't affect engine state — use
        ``update_record`` for that."""
        with self._source_lock(source_name):
            record = self._records.get(source_name, {}).get(download_id)
            return dict(record) if record is not None else None

    def iter_records_for_source(self, source_name: str) -> Iterator[DownloadRecord]:
        """Yield SHALLOW COPIES of every record owned by a source.
        Holds the source's lock briefly to snapshot, then yields
        outside the lock so callers can spend arbitrary time on each
        record.

        With the nested-dict layout this is O(source_records) — only
        touches the bucket for the requested source, not every record
        across every source.
        """
        with self._source_lock(source_name):
            source_bucket = self._records.get(source_name, {})
            snapshot = [dict(record) for record in source_bucket.values()]
        yield from snapshot

    # ------------------------------------------------------------------
    # Cross-source query dispatch — Phase B2 surface
    # ------------------------------------------------------------------
    #
    # The orchestrator historically iterated every plugin in its own
    # ``get_all_downloads`` / ``get_download_status`` / ``cancel_download``
    # methods (with hand-maintained client lists, before the registry
    # came along). That iteration logic moves into the engine here so
    # the orchestrator becomes a thin pass-through (Phase B3).
    #
    # In Phase B these methods iterate the registered plugins and call
    # their existing ``get_all_downloads`` / ``cancel_download``
    # methods — same behavior as today, just in a new home. Phase C/D
    # will replace plugin-iteration with direct engine-state queries
    # once the thread worker is also lifted.
    #
    # All methods are async to match the per-plugin contract.
    #
    # Every loop below iterates a snapshot (``list(...items())``) of
    # the plugin map: the loop bodies await, and a registration that
    # lands during an await would otherwise raise "dictionary changed
    # size during iteration".

    async def get_all_downloads(self, exclude: Tuple[str, ...] = ()):
        """Aggregated view across every registered plugin's active
        downloads. Per-plugin exceptions are swallowed (one source
        failing shouldn't take down cross-source aggregation) but
        logged at debug level — same defensive shape the legacy
        orchestrator had.

        ``exclude`` skips named sources entirely; entries may be
        canonical names or legacy aliases. The download monitor
        passes ``('soulseek',)`` so it doesn't double-fetch slskd
        transfers (it already pulled them via the slskd transfers
        endpoint earlier in the same loop).
        """
        # A bare string would otherwise be iterated character by
        # character; treat it as a single source name.
        if isinstance(exclude, str):
            exclude = (exclude,)
        excluded = {self._resolve_canonical(name) or name for name in exclude}

        all_downloads = []
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None or source_name in excluded:
                continue
            try:
                all_downloads.extend(await plugin.get_all_downloads())
            except Exception as exc:
                logger.debug("%s get_all_downloads failed: %s", source_name, exc)
        return all_downloads

    async def get_download_status(self, download_id: str):
        """Find a download_id across every plugin. Returns the first
        truthy plugin response or None if no plugin owns it. A plugin
        that raises is logged at debug level and skipped."""
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            try:
                status = await plugin.get_download_status(download_id)
                if status:
                    return status
            except Exception as exc:
                logger.debug("%s get_download_status failed: %s", source_name, exc)
        return None

    async def cancel_download(self, download_id: str,
                              source_hint: Optional[str] = None,
                              remove: bool = False) -> bool:
        """Cancel a download. ``source_hint`` is the source name (or
        legacy alias like ``'deezer_dl'``, or a real Soulseek peer
        username) — when provided, routes directly to that plugin.
        When omitted, every plugin is asked in turn until one accepts.

        Cin's review caught a bug here: legacy alias strings like
        ``'deezer_dl'`` weren't resolved to the canonical ``'deezer'``
        plugin name, so the cancel silently fell through to Soulseek.
        Resolution now goes through ``_resolve_canonical`` first.

        Returns the owning plugin's result, or False when the plugin
        raised or no plugin accepted the cancel.
        """
        # Direct routing when the caller knows the source.
        if source_hint:
            canonical = self._resolve_canonical(source_hint)
            # Streaming source names (or aliases) resolve to a
            # registered plugin. Anything else (real Soulseek peer
            # name not in our registry) routes to Soulseek.
            if canonical and canonical != 'soulseek':
                target_plugin = self._plugins.get(canonical)
                if target_plugin is not None:
                    try:
                        return await target_plugin.cancel_download(
                            download_id, source_hint, remove,
                        )
                    except Exception as exc:
                        logger.debug("%s cancel_download failed: %s", canonical, exc)
                        return False
            soulseek = self._plugins.get('soulseek')
            if soulseek is not None:
                try:
                    return await soulseek.cancel_download(download_id, source_hint, remove)
                except Exception as exc:
                    logger.debug("soulseek cancel_download failed: %s", exc)
            return False

        # No hint → ask every plugin until one cancels successfully.
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            try:
                if await plugin.cancel_download(download_id, source_hint, remove):
                    return True
            except Exception as exc:
                logger.debug("%s cancel_download failed: %s", source_name, exc)
        return False

    async def clear_all_completed_downloads(self) -> bool:
        """Best-effort cleanup of every plugin's completed-downloads
        list. Skips plugins that report not-configured (saves API
        calls + log noise).

        Returns True when every attempted plugin reported success (or
        none was attempted); a plugin that raises counts as a failure.
        """
        results = []
        for source_name, plugin in list(self._plugins.items()):
            if plugin is None:
                continue
            if hasattr(plugin, 'is_configured') and not plugin.is_configured():
                logger.debug("Skipping %s clear_all_completed_downloads (not configured)", source_name)
                continue
            try:
                results.append(await plugin.clear_all_completed_downloads())
            except Exception as exc:
                logger.warning("%s clear_all_completed_downloads failed: %s", source_name, exc)
                results.append(False)
        return all(results) if results else True

    # ------------------------------------------------------------------
    # Hybrid fallback — Phase F surface
    # ------------------------------------------------------------------

    async def search_with_fallback(self, query: str, source_chain,
                                   timeout=None, progress_callback=None):
        """Try each source in ``source_chain`` until one returns
        tracks. Skips unconfigured / unregistered sources, swallows
        per-source exceptions. Returns the first non-empty
        (tracks, albums) tuple, or ``([], [])`` when every source
        in the chain is exhausted.

        Replaces orchestrator's hand-rolled hybrid search loop. The
        chain is ordered (most-preferred first). Entries may be
        canonical source names or legacy aliases.
        """
        # Materialise once: the chain is walked here AND joined for
        # the final log line, which a one-shot iterable can't survive.
        chain_names = list(source_chain)

        # Priority numbers count positions in the resolved chain.
        for i, (source_name, plugin) in enumerate(self._resolve_chain(chain_names)):
            if plugin is None:
                logger.info("Skipping %s (not available)", source_name)
                continue
            if hasattr(plugin, 'is_configured') and not plugin.is_configured():
                logger.info("Skipping %s (not configured)", source_name)
                continue

            try:
                logger.info("Trying %s (priority %d): %s", source_name, i + 1, query)
                tracks, albums = await plugin.search(query, timeout, progress_callback)
                if tracks:
                    logger.info("%s found %d tracks", source_name, len(tracks))
                    return (tracks, albums)
            except Exception as e:
                logger.warning("%s search failed: %s", source_name, e)

        logger.warning(
            "Hybrid search: all sources (%s) found nothing for: %s",
            ', '.join(chain_names), query,
        )
        return ([], [])

    async def download_with_fallback(self, username: str, filename: str,
                                     file_size: int, source_chain) -> Optional[str]:
        """Try each source in ``source_chain`` until one accepts the
        download (returns a non-None download_id). Fixes the legacy
        bug where hybrid mode silently routed to a single source via
        the username hint with no retry on failure.

        ``username`` is treated as a hint when it matches a source
        name in the chain — that source is tried FIRST regardless of
        chain order. The match is alias-aware: a ``'deezer_dl'`` hint
        promotes a ``'deezer'`` chain entry and vice versa. Anything
        else (e.g. a real Soulseek peer name) routes through the
        chain in declared order.

        Returns the accepted download_id, or None when every source
        declined, raised, or was unavailable.
        """
        # Promote a matching source-name hint to the head of the chain.
        ordered_chain = list(source_chain)
        if username:
            hint_canonical = self._resolve_canonical(username)
            for index, entry in enumerate(ordered_chain):
                same_source = (
                    hint_canonical is not None
                    and self._resolve_canonical(entry) == hint_canonical
                )
                if entry == username or same_source:
                    ordered_chain.insert(0, ordered_chain.pop(index))
                    break

        for source_name, plugin in self._resolve_chain(ordered_chain):
            if plugin is None:
                continue
            if hasattr(plugin, 'is_configured') and not plugin.is_configured():
                continue
            try:
                download_id = await plugin.download(username, filename, file_size)
                if download_id is not None:
                    return download_id
                logger.info("%s declined download — trying next in chain", source_name)
            except Exception as e:
                logger.warning("%s download raised — trying next in chain: %s", source_name, e)

        logger.warning(
            "Hybrid download: every source in chain (%s) refused %r",
            ', '.join(ordered_chain), filename,
        )
        return None
|