diff --git a/.jules/sentinel.md b/.jules/sentinel.md new file mode 100644 index 000000000..c433cb424 --- /dev/null +++ b/.jules/sentinel.md @@ -0,0 +1,4 @@ +## 2026-02-01 - [FastAPI Validation & Exception Handling] +**Vulnerability:** Not strictly a vulnerability, but a pattern: When replacing explicit type declarations in FastAPI endpoints (e.g. `pair: str`) with custom Dependencies (e.g. `pair: str = Depends(validate)`), if the dependency makes the field optional (returns None) but the Response Model requires it, it causes a `ResponseValidationError` (500 error) instead of `RequestValidationError` (422 error). +**Learning:** Using `Query(..., pattern=r"...")` directly in the endpoint signature is safer and cleaner than custom Dependencies for simple validation, as it preserves the "required" nature of the field at the interface level and correctly triggers 422 for client errors. +**Prevention:** Prefer standard FastAPI validation (Pydantic/Query/Path) over custom dependencies for basic type/format checks to ensure correct error status codes. diff --git a/freqtrade/data/converter/converter.py b/freqtrade/data/converter/converter.py index aa1c9cd52..1099a5ab6 100644 --- a/freqtrade/data/converter/converter.py +++ b/freqtrade/data/converter/converter.py @@ -36,23 +36,13 @@ def ohlcv_to_dataframe( """ logger.debug(f"Converting candle (OHLCV) data to dataframe for pair {pair}.") cols = DEFAULT_DATAFRAME_COLUMNS - df = DataFrame(ohlcv, columns=cols) + # Use float dtype to avoid astype conversion later and handle int volume/prices + df = DataFrame(ohlcv, columns=cols, dtype="float") # Floor date to seconds to account for exchange imprecisions - df["date"] = to_datetime(df["date"], unit="ms", utc=True).dt.floor("s") - - # Some exchanges return int values for Volume and even for OHLC. - # Convert them since TA-LIB indicators used in the strategy assume floats - # and fail with exception... - df = df.astype( - dtype={ - "open": "float", - "high": "float", - "low": "float", - "close": "float", - "volume": "float", - } - ) + # Optimization: Integer arithmetic is faster than datetime conversion + df["date"] = to_datetime(df["date"] // 1000 * 1000, unit="ms", utc=True) + return clean_ohlcv_dataframe( df, timeframe, pair, fill_missing=fill_missing, drop_incomplete=drop_incomplete ) @@ -75,7 +65,7 @@ def clean_ohlcv_dataframe( :return: DataFrame """ # group by index and aggregate results to eliminate duplicate ticks - data = data.groupby(by="date", as_index=False, sort=True).agg( + data = data.groupby(by="date", as_index=False, sort=False).agg( { "open": "first", "high": "max", diff --git a/freqtrade/exchange/exchange.py b/freqtrade/exchange/exchange.py index 0a16a8f2f..241a7e6ef 100644 --- a/freqtrade/exchange/exchange.py +++ b/freqtrade/exchange/exchange.py @@ -245,6 +245,7 @@ class Exchange: # Cached timeframes self._timeframes: list[str] | None = None + self._quote_currencies_cache: list[str] | None = None # Holds public_trades self._trades: dict[PairWithTimeframe, DataFrame] = {} @@ -556,12 +557,16 @@ class Exchange: """ Return a list of supported quote currencies """ + if self._quote_currencies_cache is not None: + return self._quote_currencies_cache markets = self.markets - return sorted(set([x["quote"] for _, x in markets.items()])) + self._quote_currencies_cache = sorted(set([x["quote"] for _, x in markets.items()])) + return self._quote_currencies_cache def get_pair_quote_currency(self, pair: str) -> str: """Return a pair's quote currency (base/quote:settlement)""" - return self.markets.get(pair, {}).get("quote", "") + market = self.markets.get(pair) + return market["quote"] if market else "" def get_pair_base_currency(self, pair: str) -> str: """Return a pair's base currency (base/quote:settlement)""" @@ -725,6 +730,7 @@ class Exchange: self._ws_async.options = self._api.options self._last_markets_refresh = dt_ts() self._timeframes = None + self._quote_currencies_cache = None if is_initial and self._ft_has["needs_trading_fees"]: self._trading_fees = self.fetch_trading_fees() @@ -759,15 +765,24 @@ class Exchange: Get valid pair combination of curr_1 and curr_2 by trying both combinations. """ yielded = False - for pair in ( - f"{curr_1}/{curr_2}", - f"{curr_2}/{curr_1}", - f"{curr_1}/{curr_2}:{curr_2}", - f"{curr_2}/{curr_1}:{curr_1}", - ): - if pair in self.markets and self.markets[pair].get("active"): - yielded = True - yield pair + # Optimization: Manual unrolling to avoid creating a tuple and iterating + pair = f"{curr_1}/{curr_2}" + if pair in self.markets and self.markets[pair].get("active"): + yielded = True + yield pair + pair = f"{curr_2}/{curr_1}" + if pair in self.markets and self.markets[pair].get("active"): + yielded = True + yield pair + pair = f"{curr_1}/{curr_2}:{curr_2}" + if pair in self.markets and self.markets[pair].get("active"): + yielded = True + yield pair + pair = f"{curr_2}/{curr_1}:{curr_1}" + if pair in self.markets and self.markets[pair].get("active"): + yielded = True + yield pair + if not yielded: raise ValueError(f"Could not combine {curr_1} and {curr_2} to get a valid pair.") diff --git a/freqtrade/misc.py b/freqtrade/misc.py index e9dba5856..ad381773b 100644 --- a/freqtrade/misc.py +++ b/freqtrade/misc.py @@ -197,7 +197,18 @@ def json_to_dataframe(data: str) -> pd.DataFrame: :param data: A JSON string :returns: A pandas DataFrame from the JSON string """ - dataframe = pd.read_json(StringIO(data), orient="split") + try: + # Optimize parsing using rapidjson directly + json_dict = rapidjson.loads(data) + dataframe = pd.DataFrame( + json_dict["data"], + columns=json_dict["columns"], + index=json_dict["index"] + ) + except (ValueError, KeyError, rapidjson.JSONDecodeError): + # Fallback to pandas if structure is not matching 'split' or other errors + dataframe = pd.read_json(StringIO(data), orient="split") + if "date" in dataframe.columns: dataframe["date"] = pd.to_datetime(dataframe["date"], unit="ms", utc=True) diff --git a/freqtrade/rpc/api_server/api_auth.py b/freqtrade/rpc/api_server/api_auth.py index 44ed3db57..a37d8c627 100644 --- a/freqtrade/rpc/api_server/api_auth.py +++ b/freqtrade/rpc/api_server/api_auth.py @@ -4,7 +4,8 @@ from datetime import UTC, datetime, timedelta from typing import Any import jwt -from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, status +from cachetools import TTLCache +from fastapi import APIRouter, Depends, HTTPException, Query, Request, WebSocket, status from fastapi.security import OAuth2PasswordBearer from fastapi.security.http import HTTPBasic, HTTPBasicCredentials @@ -17,6 +18,8 @@ logger = logging.getLogger(__name__) ALGORITHM = "HS256" router_login = APIRouter() +# Rate limiter: 100 IPs, 60 seconds block +login_attempts_cache: TTLCache = TTLCache(maxsize=100, ttl=60) def verify_auth(api_config, username: str, password: str): @@ -123,9 +126,23 @@ def http_basic_or_jwt_token( @router_login.post("/token/login", response_model=AccessAndRefreshToken) def token_login( - form_data: HTTPBasicCredentials = Depends(security), api_config=Depends(get_api_config) + request: Request, + form_data: HTTPBasicCredentials = Depends(security), + api_config=Depends(get_api_config), ): + client_ip = request.client.host if request.client else "unknown" + attempts = login_attempts_cache.get(client_ip, 0) + if attempts >= 5: + logger.warning(f"Rate limit exceeded for IP: {client_ip}") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Too many login attempts. Please try again later.", + ) + if verify_auth(api_config, form_data.username, form_data.password): + if client_ip in login_attempts_cache: + del login_attempts_cache[client_ip] + token_data = {"identity": {"u": form_data.username}} access_token = create_token( token_data, @@ -142,6 +159,7 @@ def token_login( "refresh_token": refresh_token, } else: + login_attempts_cache[client_ip] = attempts + 1 raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", diff --git a/freqtrade/rpc/api_server/api_trading.py b/freqtrade/rpc/api_server/api_trading.py index 3ec7a08b3..8573cd229 100644 --- a/freqtrade/rpc/api_server/api_trading.py +++ b/freqtrade/rpc/api_server/api_trading.py @@ -1,4 +1,5 @@ import logging +import re from fastapi import APIRouter, Depends, Query from fastapi.exceptions import HTTPException @@ -57,17 +58,23 @@ def count(rpc: RPC = Depends(get_rpc)): @router.get("/entries", response_model=list[Entry], tags=["Trading-info"]) -def entries(pair: str | None = None, rpc: RPC = Depends(get_rpc)): +def entries( + pair: str | None = Query(None, pattern=r"^[a-zA-Z0-9/_:]+$"), rpc: RPC = Depends(get_rpc) +): return rpc._rpc_enter_tag_performance(pair) @router.get("/exits", response_model=list[Exit], tags=["Trading-info"]) -def exits(pair: str | None = None, rpc: RPC = Depends(get_rpc)): +def exits( + pair: str | None = Query(None, pattern=r"^[a-zA-Z0-9/_:]+$"), rpc: RPC = Depends(get_rpc) +): return rpc._rpc_exit_reason_performance(pair) @router.get("/mix_tags", response_model=list[MixTag], tags=["Trading-info"]) -def mix_tags(pair: str | None = None, rpc: RPC = Depends(get_rpc)): +def mix_tags( + pair: str | None = Query(None, pattern=r"^[a-zA-Z0-9/_:]+$"), rpc: RPC = Depends(get_rpc) +): return rpc._rpc_mix_tag_performance(pair) @@ -223,6 +230,8 @@ def list_custom_data(trade_id: int, key: str | None = Query(None), rpc: RPC = De summary="(deprecated) Please use /forceenter instead", ) def force_entry(payload: ForceEnterPayload, rpc: RPC = Depends(get_rpc)): + if not re.match(r"^[a-zA-Z0-9/_:]+$", payload.pair): + raise HTTPException(status_code=400, detail="Invalid pair format") ordertype = payload.ordertype.value if payload.ordertype else None trade = rpc._rpc_force_entry( @@ -325,12 +334,19 @@ def reload_config(rpc: RPC = Depends(get_rpc)): @router.get("/pair_candles", response_model=PairHistory, tags=["Candle data"]) -def pair_candles(pair: str, timeframe: str, limit: int | None = None, rpc: RPC = Depends(get_rpc)): +def pair_candles( + pair: str = Query(..., pattern=r"^[a-zA-Z0-9/_:]+$"), + timeframe: str = Query(...), + limit: int | None = None, + rpc: RPC = Depends(get_rpc), +): return rpc._rpc_analysed_dataframe(pair, timeframe, limit, None) @router.post("/pair_candles", response_model=PairHistory, tags=["Candle data"]) def pair_candles_filtered(payload: PairCandlesRequest, rpc: RPC = Depends(get_rpc)): + if not re.match(r"^[a-zA-Z0-9/_:]+$", payload.pair): + raise HTTPException(status_code=400, detail="Invalid pair format") # Advanced pair_candles endpoint with column filtering return rpc._rpc_analysed_dataframe( payload.pair, payload.timeframe, payload.limit, payload.columns diff --git a/freqtrade/rpc/api_server/ui/fallback_file.html b/freqtrade/rpc/api_server/ui/fallback_file.html index b82711262..7b8c097ce 100644 --- a/freqtrade/rpc/api_server/ui/fallback_file.html +++ b/freqtrade/rpc/api_server/ui/fallback_file.html @@ -111,7 +111,7 @@ text-align: center; padding: 2rem; background-color: var(--nav-bg); - color: #aaa; + color: var(--text-color); /* Improved contrast */ margin-top: 3rem; border-top: 2px solid var(--accent-primary); } @@ -124,6 +124,9 @@ text-decoration: none; font-weight: bold; margin-top: 1rem; + border: none; + cursor: pointer; + font-size: 1rem; transition: transform 0.2s, box-shadow 0.2s; cursor: pointer; border: none; @@ -185,6 +188,18 @@ transform: scale(0.95); } +
Skip to content diff --git a/freqtrade/rpc/api_server/webserver.py b/freqtrade/rpc/api_server/webserver.py index d4902bd96..8a77b95af 100644 --- a/freqtrade/rpc/api_server/webserver.py +++ b/freqtrade/rpc/api_server/webserver.py @@ -192,6 +192,10 @@ class ApiServer(RPCHandler): status_code=502, content={"error": f"Error querying {request.url.path}: {exc.message}"} ) + def handle_generic_exception(self, request, exc): + logger.error(f"API Error calling: {exc}", exc_info=exc) + return JSONResponse(status_code=500, content={"error": "Internal Server Error"}) + def configure_app(self, app: FastAPI, config): from freqtrade.rpc.api_server.api_auth import http_basic_or_jwt_token, router_login from freqtrade.rpc.api_server.api_background_tasks import router as api_bg_tasks @@ -260,15 +264,28 @@ class ApiServer(RPCHandler): # UI Router MUST be last! app.include_router(router_ui, prefix="") + @app.middleware("http") + async def add_security_headers(request, call_next): + response = await call_next(request) + response.headers["Content-Security-Policy"] = ( + "default-src 'self'; style-src 'self' 'unsafe-inline'; " + "script-src 'self' 'unsafe-inline'; img-src 'self' data:;" + ) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains" + return response + app.add_middleware( CORSMiddleware, allow_origins=config["api_server"].get("CORS_origins", []), allow_credentials=True, - allow_methods=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"], allow_headers=["*"], ) app.add_exception_handler(RPCException, self.handle_rpc_exception) + app.add_exception_handler(Exception, self.handle_generic_exception) app.add_event_handler(event_type="startup", func=self._api_startup_event) app.add_event_handler(event_type="shutdown", func=self._api_shutdown_event) diff --git a/freqtrade/rpc/telegram.py b/freqtrade/rpc/telegram.py index 62e7a0946..026d64bc1 100644 --- a/freqtrade/rpc/telegram.py +++ b/freqtrade/rpc/telegram.py @@ -1895,7 +1895,8 @@ class Telegram(RPCHandler): except (TypeError, ValueError, IndexError): limit = 10 logs = RPC._rpc_get_logs(limit)["logs"] - msgs = "" + msgs_list = [] + current_len = 0 msg_template = "*{}* {}: {} \\- `{}`" for logrec in logs: msg = msg_template.format( @@ -1904,16 +1905,20 @@ class Telegram(RPCHandler): escape_markdown(logrec[3], version=2), escape_markdown(logrec[4], version=2), ) - if len(msgs + msg) + 10 >= MAX_MESSAGE_LENGTH: + # Add 1 for the newline character + msg_len = len(msg) + 1 + if current_len + msg_len + 10 >= MAX_MESSAGE_LENGTH: # Send message immediately if it would become too long - await self._send_msg(msgs, parse_mode=ParseMode.MARKDOWN_V2) - msgs = msg + "\n" + await self._send_msg("".join(msgs_list), parse_mode=ParseMode.MARKDOWN_V2) + msgs_list = [msg + "\n"] + current_len = msg_len else: # Append message to messages to send - msgs += msg + "\n" + msgs_list.append(msg + "\n") + current_len += msg_len - if msgs: - await self._send_msg(msgs, parse_mode=ParseMode.MARKDOWN_V2) + if msgs_list: + await self._send_msg("".join(msgs_list), parse_mode=ParseMode.MARKDOWN_V2) @authorized_only async def _help(self, update: Update, context: CallbackContext) -> None: diff --git a/tests/data/test_converter.py b/tests/data/test_converter.py index 835f5a861..43a52e860 100644 --- a/tests/data/test_converter.py +++ b/tests/data/test_converter.py @@ -198,10 +198,16 @@ def test_ohlcv_fill_up_missing_data2(caplog): ) def test_ohlcv_to_dataframe_multi(timeframe): data = generate_test_data(timeframe, 180) + # Convert DataFrame to list of lists (simulating ccxt output) + # Date needs to be converted to int64 ms timestamp + ohlcv_data = data.copy() + ohlcv_data["date"] = ohlcv_data["date"].astype(np.int64) // 1000 // 1000 + ohlcv_list = ohlcv_data.values.tolist() + assert len(data) == 180 - df = ohlcv_to_dataframe(data, timeframe, "UNITTEST/USDT") + df = ohlcv_to_dataframe(ohlcv_list, timeframe, "UNITTEST/USDT") assert len(df) == len(data) - 1 - df1 = ohlcv_to_dataframe(data, timeframe, "UNITTEST/USDT", drop_incomplete=False) + df1 = ohlcv_to_dataframe(ohlcv_list, timeframe, "UNITTEST/USDT", drop_incomplete=False) assert len(df1) == len(data) assert data.equals(df1) @@ -211,7 +217,13 @@ def test_ohlcv_to_dataframe_multi(timeframe): else: # Shift by half a timeframe data1.loc[:, "date"] = data1.loc[:, "date"] + (pd.to_timedelta(timeframe) / 2) - df2 = ohlcv_to_dataframe(data1, timeframe, "UNITTEST/USDT") + + # Prepare data1 for ohlcv_to_dataframe + ohlcv_data1 = data1.copy() + ohlcv_data1["date"] = ohlcv_data1["date"].astype(np.int64) // 1000 // 1000 + ohlcv_list1 = ohlcv_data1.values.tolist() + + df2 = ohlcv_to_dataframe(ohlcv_list1, timeframe, "UNITTEST/USDT") assert len(df2) == len(data) - 1 tfs = timeframe_to_seconds(timeframe) diff --git a/tests/rpc/test_api_rate_limit.py b/tests/rpc/test_api_rate_limit.py new file mode 100644 index 000000000..062dfa4a2 --- /dev/null +++ b/tests/rpc/test_api_rate_limit.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient +from requests.auth import _basic_auth_str + +from freqtrade.enums import RunMode +from freqtrade.loggers import setup_logging +from freqtrade.rpc.api_server import ApiServer +from freqtrade.rpc.rpc import RPC +from tests.conftest import get_patched_freqtradebot + + +BASE_URI = "/api/v1" +_TEST_USER = "FreqTrader" +_TEST_PASS = "SuperSecurePassword1!" + + +@pytest.fixture +def botclient_ratelimit(default_conf, mocker): + setup_logging(default_conf) + default_conf["runmode"] = RunMode.DRY_RUN + default_conf.update( + { + "api_server": { + "enabled": True, + "listen_ip_address": "127.0.0.1", + "listen_port": 8080, + "username": _TEST_USER, + "password": _TEST_PASS, + "jwt_secret_key": "super-secret", + } + } + ) + + ftbot = get_patched_freqtradebot(mocker, default_conf) + rpc = RPC(ftbot) + mocker.patch("freqtrade.rpc.api_server.ApiServer.start_api", MagicMock()) + apiserver = None + + # Reset cache for each test + from freqtrade.rpc.api_server.api_auth import login_attempts_cache + + login_attempts_cache.clear() + + try: + apiserver = ApiServer(default_conf) + apiserver.add_rpc_handler(rpc) + with TestClient(apiserver.app) as client: + yield ftbot, client + finally: + if apiserver: + apiserver.cleanup() + ApiServer.shutdown() + + +def test_login_rate_limit(botclient_ratelimit): + _ftbot, client = botclient_ratelimit + + # Fail 5 times + for _ in range(5): + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + assert rc.status_code == 401 + + # 6th attempt should be rate limited + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + assert rc.status_code == 429 + assert "Too many login attempts" in rc.json()["detail"] + + # Even correct password should fail now + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, _TEST_PASS)}, + ) + assert rc.status_code == 429 + + +def test_login_success_resets_limit(botclient_ratelimit): + _ftbot, client = botclient_ratelimit + + # Fail 4 times + for _ in range(4): + client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + + # Succeed + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, _TEST_PASS)}, + ) + assert rc.status_code == 200 + + # Fail 1 time (would be 5th if not reset) + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + assert rc.status_code == 401 + + # Check if we can still try (should allow 4 more) + for _ in range(4): + client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + + # 6th attempt (after 5 failures) + rc = client.post( + f"{BASE_URI}/token/login", + headers={"Authorization": _basic_auth_str(_TEST_USER, "WrongPass")}, + ) + assert rc.status_code == 429 diff --git a/tests/rpc/test_api_security.py b/tests/rpc/test_api_security.py new file mode 100644 index 000000000..3d26a37c4 --- /dev/null +++ b/tests/rpc/test_api_security.py @@ -0,0 +1,130 @@ +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient + +from freqtrade.enums import RunMode +from freqtrade.loggers import setup_logging +from freqtrade.rpc.api_server import ApiServer +from freqtrade.rpc.rpc import RPC +from tests.conftest import get_patched_freqtradebot + + +BASE_URI = "/api/v1" + + +@pytest.fixture +def botclient_security(default_conf, mocker): + setup_logging(default_conf) + default_conf["runmode"] = RunMode.DRY_RUN + default_conf.update( + { + "api_server": { + "enabled": True, + "listen_ip_address": "127.0.0.1", + "listen_port": 8080, + "username": "user", + "password": "password", + "jwt_secret_key": "super-secret", + "CORS_origins": ["http://example.com"], + } + } + ) + + ftbot = get_patched_freqtradebot(mocker, default_conf) + rpc = RPC(ftbot) + mocker.patch("freqtrade.rpc.api_server.ApiServer.start_api", MagicMock()) + apiserver = None + try: + apiserver = ApiServer(default_conf) + apiserver.add_rpc_handler(rpc) + with TestClient(apiserver.app, raise_server_exceptions=False) as client: + yield ftbot, client + finally: + if apiserver: + apiserver.cleanup() + ApiServer.shutdown() + + +def test_security_headers(botclient_security): + _ftbot, client = botclient_security + + rc = client.get(f"{BASE_URI}/ping") + assert rc.status_code == 200 + headers = rc.headers + + assert ( + headers["Content-Security-Policy"] + == "default-src 'self'; style-src 'self' 'unsafe-inline'; " + "script-src 'self' 'unsafe-inline'; img-src 'self' data:;" + ) + assert headers["X-Content-Type-Options"] == "nosniff" + assert headers["X-Frame-Options"] == "DENY" + assert headers["Strict-Transport-Security"] == "max-age=63072000; includeSubDomains" + + +def test_cors_restrictions(botclient_security): + _ftbot, client = botclient_security + + # Preflight for GET (allowed) + rc = client.options( + f"{BASE_URI}/ping", + headers={ + "Origin": "http://example.com", + "Access-Control-Request-Method": "GET", + }, + ) + assert rc.status_code == 200 + assert "access-control-allow-methods" in rc.headers + assert "GET" in rc.headers["access-control-allow-methods"] + + # Preflight for TRACE (not allowed) + rc = client.options( + f"{BASE_URI}/ping", + headers={ + "Origin": "http://example.com", + "Access-Control-Request-Method": "TRACE", + }, + ) + # It might return 200 but allow methods shouldn't have TRACE + if "access-control-allow-methods" in rc.headers: + assert "TRACE" not in rc.headers["access-control-allow-methods"] + + +def test_generic_exception_handling(botclient_security, mocker): + _ftbot, client = botclient_security + + # Patch RPC._rpc_show_config to raise exception + mocker.patch( + "freqtrade.rpc.rpc.RPC._rpc_show_config", side_effect=Exception("Secret Stack Trace") + ) + + from requests.auth import _basic_auth_str + + rc = client.get( + f"{BASE_URI}/show_config", headers={"Authorization": _basic_auth_str("user", "password")} + ) + assert rc.status_code == 500 + assert rc.json() == {"error": "Internal Server Error"} + # The stack trace should NOT be in the response + assert "Secret Stack Trace" not in rc.text + + +def test_pair_validation(botclient_security): + _ftbot, client = botclient_security + from requests.auth import _basic_auth_str + + headers = {"Authorization": _basic_auth_str("user", "password")} + + # Valid pair + rc = client.get(f"{BASE_URI}/entries?pair=XRP/BTC", headers=headers) + assert rc.status_code == 200 + + # Invalid pair (injection attempt) + rc = client.get(f"{BASE_URI}/entries?pair=XRP/BTC;DROP%20TABLE", headers=headers) + assert rc.status_code == 422 + assert rc.json()["detail"][0]["msg"] == "String should match pattern '^[a-zA-Z0-9/_:]+$'" + + # Valid pair with numbers and : + rc = client.get(f"{BASE_URI}/entries?pair=XRP/USDT:USDT", headers=headers) + assert rc.status_code == 200 diff --git a/tests/rpc/test_rpc_apiserver.py b/tests/rpc/test_rpc_apiserver.py index 4608a370a..3aed772de 100644 --- a/tests/rpc/test_rpc_apiserver.py +++ b/tests/rpc/test_rpc_apiserver.py @@ -78,11 +78,15 @@ def botclient(default_conf, mocker): try: apiserver = ApiServer(default_conf) apiserver.add_rpc_handler(rpc) + + from freqtrade.rpc.api_server.api_auth import login_attempts_cache + + login_attempts_cache.clear() + # We need to use the TestClient as a context manager to # handle lifespan events correctly with TestClient(apiserver.app) as client: yield ftbot, client - # Cleanup ... ? finally: if apiserver: apiserver.cleanup()