diff --git a/scripts/mcp/DiscoveryAgent/Rich/discover_cli.py b/scripts/mcp/DiscoveryAgent/Rich/discover_cli.py index 4473377d7..93c02d9d0 100644 --- a/scripts/mcp/DiscoveryAgent/Rich/discover_cli.py +++ b/scripts/mcp/DiscoveryAgent/Rich/discover_cli.py @@ -1,26 +1,33 @@ #!/usr/bin/env python3 """ Database Discovery Agent (Async CLI, Rich UI) -Key fixes vs earlier version: -- MCP tools are invoked via JSON-RPC method **tools/call** (NOT by calling tool name as method). -- Supports HTTPS + Bearer token + optional insecure TLS (self-signed certs). - -Environment variables (or CLI flags): -- MCP_ENDPOINT (e.g. https://127.0.0.1:6071/mcp/query) -- MCP_AUTH_TOKEN (Bearer token, if required) -- MCP_INSECURE_TLS=1 to disable TLS verification (like curl -k) - -- LLM_BASE_URL (OpenAI-compatible base, e.g. https://api.openai.com) -- LLM_API_KEY -- LLM_MODEL +This version focuses on robustness + debuggability: + +MCP: +- Calls tools via JSON-RPC method: tools/call +- Supports HTTPS + Bearer token + optional insecure TLS (self-signed) via: + - MCP_INSECURE_TLS=1 or --mcp-insecure-tls + +LLM: +- OpenAI-compatible *or* OpenAI-like gateways with nonstandard base paths +- Configurable chat path (NO more hardcoded /v1): + - LLM_CHAT_PATH (default: /v1/chat/completions) or --llm-chat-path +- Stronger tracing: + - logs HTTP status + response text snippet on every LLM request +- Safer JSON parsing: + - treats empty content as an error + - optional response_format={"type":"json_object"} (enable with --llm-json-mode) + +Environment variables: +- MCP_ENDPOINT, MCP_AUTH_TOKEN, MCP_INSECURE_TLS +- LLM_BASE_URL, LLM_API_KEY, LLM_MODEL, LLM_CHAT_PATH, LLM_INSECURE_TLS, LLM_JSON_MODE """ import argparse import asyncio import json import os -import sys import time import uuid import traceback @@ -37,7 +45,6 @@ from rich.table import Table from rich.text import Text from rich.layout import Layout - ExpertName = Literal["planner", "structural", "statistical", "semantic", "query"] 
KNOWN_MCP_TOOLS = { @@ -55,7 +62,6 @@ ALLOWED_TOOLS: Dict[ExpertName, set] = { "query": {"explain_sql", "run_sql_readonly", "catalog_search", "catalog_get", "catalog_list"}, } - class ToolCall(BaseModel): name: str args: Dict[str, Any] = Field(default_factory=dict) @@ -92,7 +98,6 @@ class ExpertReflect(BaseModel): insights: List[Dict[str, Any]] = Field(default_factory=list) questions_for_user: List[QuestionForUser] = Field(default_factory=list) - class TraceLogger: def __init__(self, path: Optional[str]): self.path = path @@ -105,7 +110,6 @@ class TraceLogger: with open(self.path, "a", encoding="utf-8") as f: f.write(json.dumps(rec, ensure_ascii=False) + "\n") - class MCPError(RuntimeError): pass @@ -118,7 +122,7 @@ class MCPClient: async def rpc(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any: req_id = str(uuid.uuid4()) - payload = {"jsonrpc": "2.0", "id": req_id, "method": method} + payload: Dict[str, Any] = {"jsonrpc": "2.0", "id": req_id, "method": method} if params is not None: payload["params"] = params @@ -144,7 +148,6 @@ class MCPClient: result = await self.rpc("tools/call", {"name": tool_name, "arguments": args}) self.trace.write({"type": "mcp.result", "tool": tool_name, "result": result}) - # Expected: {"success": true, "result": ...} if isinstance(result, dict) and "success" in result: if not result.get("success", False): raise MCPError(f"MCP tool failed: {tool_name}: {result}") @@ -154,64 +157,117 @@ class MCPClient: async def close(self): await self.client.aclose() - class LLMError(RuntimeError): pass class LLMClient: - def __init__(self, base_url: str, api_key: str, model: str, trace: TraceLogger, insecure_tls: bool = False): + """OpenAI-compatible chat client with configurable path and better tracing.""" + def __init__( + self, + base_url: str, + api_key: str, + model: str, + trace: TraceLogger, + *, + insecure_tls: bool = False, + chat_path: str = "/v1/chat/completions", + json_mode: bool = False, + ): self.base_url = 
base_url.rstrip("/") + self.chat_path = "/" + chat_path.strip("/") self.api_key = api_key self.model = model self.trace = trace - self.client = httpx.AsyncClient(timeout=120.0, verify=(not insecure_tls)) + self.json_mode = json_mode + self.client = httpx.AsyncClient(timeout=180.0, verify=(not insecure_tls)) + + async def close(self): + await self.client.aclose() async def chat_json(self, system: str, user: str, *, max_tokens: int = 1200) -> Dict[str, Any]: - url = f"{self.base_url}/chat/completions" + url = f"{self.base_url}{self.chat_path}" headers = {"Content-Type": "application/json"} if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - payload = { + payload: Dict[str, Any] = { "model": self.model, "temperature": 0.2, "max_tokens": max_tokens, + "stream": False, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": user}, ], } + if self.json_mode: + payload["response_format"] = {"type": "json_object"} + + self.trace.write({ + "type": "llm.request", + "model": self.model, + "url": url, + "system": system[:4000], + "user": user[:8000], + "json_mode": self.json_mode, + }) - self.trace.write({"type": "llm.request", "model": self.model, "system": system[:4000], "user": user[:8000]}) r = await self.client.post(url, json=payload, headers=headers) + + body_snip = r.text[:2000] if r.text else "" + self.trace.write({"type": "llm.http", "status": r.status_code, "body_snip": body_snip}) + if r.status_code != 200: raise LLMError(f"LLM HTTP {r.status_code}: {r.text}") - data = r.json() + + try: + data = r.json() + except Exception as e: + raise LLMError(f"LLM returned non-JSON HTTP body: {body_snip}") from e + try: content = data["choices"][0]["message"]["content"] except Exception: - raise LLMError(f"Unexpected LLM response: {data}") + self.trace.write({"type": "llm.unexpected_schema", "keys": list(data.keys())}) + raise LLMError(f"Unexpected LLM response schema. Keys={list(data.keys())}. 
Body={body_snip}") + + if content is None: + content = "" self.trace.write({"type": "llm.raw", "content": content}) + if not str(content).strip(): + raise LLMError("LLM returned empty content (check LLM_CHAT_PATH, auth, or gateway compatibility).") + try: return json.loads(content) except Exception: - repair_payload = { + repair_payload: Dict[str, Any] = { "model": self.model, "temperature": 0.0, "max_tokens": 1200, + "stream": False, "messages": [ {"role": "system", "content": "Return ONLY valid JSON, no prose."}, {"role": "user", "content": f"Fix into valid JSON:\n\n{content}"}, ], } - self.trace.write({"type": "llm.repair.request", "bad": content[:8000]}) + if self.json_mode: + repair_payload["response_format"] = {"type": "json_object"} + + self.trace.write({"type": "llm.repair.request", "bad": str(content)[:8000]}) r2 = await self.client.post(url, json=repair_payload, headers=headers) + self.trace.write({"type": "llm.repair.http", "status": r2.status_code, "body_snip": (r2.text[:2000] if r2.text else "")}) + if r2.status_code != 200: raise LLMError(f"LLM repair HTTP {r2.status_code}: {r2.text}") + data2 = r2.json() - content2 = data2["choices"][0]["message"]["content"] + content2 = data2.get("choices", [{}])[0].get("message", {}).get("content", "") self.trace.write({"type": "llm.repair.raw", "content": content2}) + + if not str(content2).strip(): + raise LLMError("LLM repair returned empty content (gateway misconfig or unsupported endpoint).") + try: return json.loads(content2) except Exception as e: @@ -257,7 +313,6 @@ Rules: - Ask at most ONE question per reflect step, only if it materially changes exploration. 
""" - @dataclass class UIState: run_id: str @@ -274,7 +329,6 @@ class UIState: if self.planned_tasks is None: self.planned_tasks = [] - def normalize_list(res: Any, keys: Tuple[str, ...]) -> List[Any]: if isinstance(res, list): return res @@ -290,16 +344,11 @@ def item_name(x: Any) -> str: return str(x["name"]) return str(x) -def now_iso() -> str: - return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - - class Agent: - def __init__(self, mcp: MCPClient, llm: LLMClient, trace: TraceLogger, debug: bool): + def __init__(self, mcp: MCPClient, llm: LLMClient, trace: TraceLogger): self.mcp = mcp self.llm = llm self.trace = trace - self.debug = debug async def planner(self, schema: str, tables: List[str], user_intent: Optional[Dict[str, Any]]) -> List[PlannedTask]: user = json.dumps({ @@ -310,23 +359,15 @@ class Agent: }, ensure_ascii=False) raw = await self.llm.chat_json(PLANNER_SYSTEM, user, max_tokens=900) - try: - out = PlannerOut.model_validate(raw) - except ValidationError as e: - raise LLMError(f"Planner output invalid: {e}\nraw={raw}") - - tasks = [t for t in out.tasks if t.expert in ("structural","statistical","semantic","query")] + out = PlannerOut.model_validate(raw) + tasks = [t for t in out.tasks if t.expert in ("structural", "statistical", "semantic", "query")] tasks.sort(key=lambda t: t.priority, reverse=True) return tasks[:6] async def expert_act(self, expert: ExpertName, ctx: Dict[str, Any]) -> ExpertAct: system = EXPERT_ACT_SYSTEM.format(expert=expert, allowed_tools=sorted(ALLOWED_TOOLS[expert])) raw = await self.llm.chat_json(system, json.dumps(ctx, ensure_ascii=False), max_tokens=900) - try: - act = ExpertAct.model_validate(raw) - except ValidationError as e: - raise LLMError(f"{expert} ACT invalid: {e}\nraw={raw}") - + act = ExpertAct.model_validate(raw) act.tool_calls = act.tool_calls[:6] for c in act.tool_calls: if c.name not in KNOWN_MCP_TOOLS: @@ -340,27 +381,18 @@ class Agent: user = dict(ctx) user["tool_results"] = tool_results raw = await 
self.llm.chat_json(system, json.dumps(user, ensure_ascii=False), max_tokens=1200) - try: - ref = ExpertReflect.model_validate(raw) - except ValidationError as e: - raise LLMError(f"{expert} REFLECT invalid: {e}\nraw={raw}") - return ref + return ExpertReflect.model_validate(raw) async def apply_catalog_writes(self, writes: List[CatalogWrite]): for w in writes: await self.mcp.call_tool("catalog_upsert", { - "kind": w.kind, - "key": w.key, - "document": w.document, - "tags": w.tags, - "links": w.links + "kind": w.kind, "key": w.key, "document": w.document, "tags": w.tags, "links": w.links }) async def run(self, ui: UIState, schema: Optional[str], max_iterations: int, tasks_per_iter: int): ui.phase = "bootstrap" - schemas_res = await self.mcp.call_tool("list_schemas", {"page_size": 50}) - schemas = schemas_res if isinstance(schemas_res, list) else normalize_list(schemas_res, ("schemas","items","result")) + schemas = normalize_list(await self.mcp.call_tool("list_schemas", {"page_size": 50}), ("schemas", "items", "result")) if not schemas: raise MCPError("No schemas returned by MCP list_schemas") @@ -368,52 +400,27 @@ ui.last_event = f"Selected schema: {chosen_schema}" tables_res = await self.mcp.call_tool("list_tables", {"schema": chosen_schema, "page_size": 500}) - tables = tables_res if isinstance(tables_res, list) else normalize_list(tables_res, ("tables","items","result")) + tables = tables_res if isinstance(tables_res, list) else normalize_list(tables_res, ("tables", "items", "result")) table_names = [item_name(t) for t in tables] if not table_names: raise MCPError(f"No tables returned by MCP list_tables(schema={chosen_schema})") - user_intent = None - try: - ig = await self.mcp.call_tool("catalog_get", {"kind": "intent", "key": f"intent/{ui.run_id}"}) - if isinstance(ig, dict) and ig.get("document"): - user_intent = json.loads(ig["document"]) - except Exception: - user_intent = None - ui.phase = "running" - no_progress_streak = 0 - for it in 
range(1, max_iterations + 1): ui.iteration = it ui.last_event = "Planning tasks…" - tasks = await self.planner(chosen_schema, table_names, user_intent) + tasks = await self.planner(chosen_schema, table_names, None) ui.planned_tasks = tasks ui.last_event = f"Planned {len(tasks)} tasks" - if not tasks: - ui.phase = "done" - ui.last_event = "No tasks from planner" - return - executed = 0 - before_insights = ui.insights - before_writes = ui.catalog_writes - for task in tasks: if executed >= tasks_per_iter: break executed += 1 expert = task.expert - ctx = { - "run_id": ui.run_id, - "schema": task.schema, - "table": task.table, - "goal": task.goal, - "user_intent": user_intent, - "note": "Choose minimal tool calls to advance discovery." - } + ctx = {"run_id": ui.run_id, "schema": task.schema, "table": task.table, "goal": task.goal} ui.last_event = f"{expert} ACT: {task.goal}" + (f" ({task.table})" if task.table else "") act = await self.expert_act(expert, ctx) @@ -427,49 +434,16 @@ class Agent: ui.last_event = f"{expert} REFLECT" ref = await self.expert_reflect(expert, ctx, tool_results) - if ref.catalog_writes: await self.apply_catalog_writes(ref.catalog_writes) ui.catalog_writes += len(ref.catalog_writes) - - for q in ref.questions_for_user[:1]: - payload = { - "run_id": ui.run_id, - "question_id": q.question_id, - "title": q.title, - "prompt": q.prompt, - "options": q.options, - "created_at": now_iso() - } - await self.mcp.call_tool("catalog_upsert", { - "kind": "question", - "key": f"question/{ui.run_id}/{q.question_id}", - "document": json.dumps(payload, ensure_ascii=False), - "tags": f"run:{ui.run_id}" - }) - ui.catalog_writes += 1 - ui.insights += len(ref.insights) - gained_insights = ui.insights - before_insights - gained_writes = ui.catalog_writes - before_writes - if gained_insights == 0 and gained_writes == 0: - no_progress_streak += 1 - else: - no_progress_streak = 0 - - if no_progress_streak >= 2: - ui.phase = "done" - ui.last_event = "Stopping: diminishing 
returns" - return - ui.phase = "done" - ui.last_event = "Finished: max_iterations reached" - + ui.last_event = "Finished" def render(ui: UIState) -> Layout: layout = Layout() - header = Text() header.append("Database Discovery Agent ", style="bold") header.append(f"(run_id: {ui.run_id})", style="dim") @@ -488,25 +462,26 @@ tasks_table.add_column("Expert", width=11) tasks_table.add_column("Goal") tasks_table.add_column("Table", style="dim") - for t in (ui.planned_tasks or [])[:10]: tasks_table.add_row(f"{t.priority:.2f}", t.expert, t.goal, t.table or "") events = Text() if ui.last_event: - events.append(ui.last_event, style="white") + events.append(ui.last_event) if ui.last_error: - events.append("\n") + events.append("\n") events.append(ui.last_error, style="bold red") layout.split_column( Layout(Panel(header, border_style="cyan"), size=3), Layout(Panel(status, title="Status", border_style="green"), size=8), Layout(Panel(tasks_table, border_style="magenta"), ratio=2), - Layout(Panel(events, title="Last event", border_style="yellow"), size=6), + Layout(Panel(events, title="Last event", border_style="yellow"), size=7), ) return layout +def _truthy(s: str) -> bool: + return s in ("1", "true", "TRUE", "yes", "YES", "y", "Y") async def cmd_run(args: argparse.Namespace): console = Console() @@ -514,27 +489,30 @@ mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "") mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN") - mcp_insecure = args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES")) + mcp_insecure = args.mcp_insecure_tls or _truthy(os.getenv("MCP_INSECURE_TLS", "0")) llm_base = args.llm_base_url or os.getenv("LLM_BASE_URL", "https://api.openai.com") llm_key = args.llm_api_key or os.getenv("LLM_API_KEY", "") llm_model = args.llm_model or os.getenv("LLM_MODEL", "gpt-4o-mini") - llm_insecure = args.llm_insecure_tls or 
(os.getenv("LLM_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES")) + llm_chat_path = args.llm_chat_path or os.getenv("LLM_CHAT_PATH", "/v1/chat/completions") + llm_insecure = args.llm_insecure_tls or _truthy(os.getenv("LLM_INSECURE_TLS", "0")) + llm_json_mode = args.llm_json_mode or _truthy(os.getenv("LLM_JSON_MODE", "0")) if not mcp_endpoint: - console.print("[bold red]MCP endpoint is required (set MCP_ENDPOINT or --mcp-endpoint)[/bold red]") - raise SystemExit(2) - - if "openai.com" in llm_base and not llm_key: - console.print("[bold red]LLM_API_KEY is required for OpenAI[/bold red]") + console.print("[bold red]MCP_ENDPOINT missing (or --mcp-endpoint)[/bold red]") raise SystemExit(2) run_id = args.run_id or str(uuid.uuid4()) ui = UIState(run_id=run_id) mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure) - llm = LLMClient(llm_base, llm_key, llm_model, trace, insecure_tls=llm_insecure) - agent = Agent(mcp, llm, trace, debug=args.debug) + llm = LLMClient( + llm_base, llm_key, llm_model, trace, + insecure_tls=llm_insecure, + chat_path=llm_chat_path, + json_mode=llm_json_mode, + ) + agent = Agent(mcp, llm, trace) async def runner(): try: @@ -546,100 +524,54 @@ if args.debug: tb = traceback.format_exc() trace.write({"type": "error.traceback", "traceback": tb}) - ui.last_error += "\n" + tb + ui.last_error += "\n" + tb finally: await mcp.close() await llm.close() - task = asyncio.create_task(runner()) + t = asyncio.create_task(runner()) with Live(render(ui), refresh_per_second=8, console=console): - while not task.done(): + while not t.done(): await asyncio.sleep(0.1) console.print(render(ui)) if ui.phase == "error": raise SystemExit(1) - -async def cmd_intent(args: argparse.Namespace): - console = Console() - trace = TraceLogger(args.trace) - - mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "") - mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN") - mcp_insecure = 
args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES")) - - if not mcp_endpoint: - console.print("[bold red]MCP endpoint is required[/bold red]") - raise SystemExit(2) - - payload = { - "run_id": args.run_id, - "audience": args.audience, - "goals": args.goals, - "constraints": {}, - "updated_at": now_iso() - } - for kv in (args.constraint or []): - if "=" in kv: - k, v = kv.split("=", 1) - payload["constraints"][k] = v - - mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure) - try: - await mcp.call_tool("catalog_upsert", { - "kind": "intent", - "key": f"intent/{args.run_id}", - "document": json.dumps(payload, ensure_ascii=False), - "tags": f"run:{args.run_id}" - }) - console.print("[green]Intent stored[/green]") - finally: - await mcp.close() - - def build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(prog="discover_cli", description="Database Discovery Agent (Async CLI)") sub = p.add_subparsers(dest="cmd", required=True) common = argparse.ArgumentParser(add_help=False) - common.add_argument("--mcp-endpoint", default=None, help="MCP JSON-RPC endpoint (or MCP_ENDPOINT env)") - common.add_argument("--mcp-auth-token", default=None, help="MCP auth token (or MCP_AUTH_TOKEN env)") - common.add_argument("--mcp-insecure-tls", action="store_true", help="Disable MCP TLS verification (like curl -k)") - common.add_argument("--llm-base-url", default=None, help="OpenAI-compatible base URL (or LLM_BASE_URL env)") - common.add_argument("--llm-api-key", default=None, help="LLM API key (or LLM_API_KEY env)") - common.add_argument("--llm-model", default=None, help="LLM model (or LLM_MODEL env)") - common.add_argument("--llm-insecure-tls", action="store_true", help="Disable LLM TLS verification") - common.add_argument("--trace", default=None, help="Write JSONL trace to this file") - common.add_argument("--debug", action="store_true", help="Show stack traces") - - prun = sub.add_parser("run", 
parents=[common], help="Run discovery") - prun.add_argument("--run-id", default=None, help="Optional run id (uuid). If omitted, generated.") - prun.add_argument("--schema", default=None, help="Optional schema to focus on") + common.add_argument("--mcp-endpoint", default=None) + common.add_argument("--mcp-auth-token", default=None) + common.add_argument("--mcp-insecure-tls", action="store_true") + common.add_argument("--llm-base-url", default=None) + common.add_argument("--llm-api-key", default=None) + common.add_argument("--llm-model", default=None) + common.add_argument("--llm-chat-path", default=None, help="e.g. /v1/chat/completions or /v4/chat/completions") + common.add_argument("--llm-insecure-tls", action="store_true") + common.add_argument("--llm-json-mode", action="store_true") + common.add_argument("--trace", default=None) + common.add_argument("--debug", action="store_true") + + prun = sub.add_parser("run", parents=[common]) + prun.add_argument("--run-id", default=None) + prun.add_argument("--schema", default=None) prun.add_argument("--max-iterations", type=int, default=6) prun.add_argument("--tasks-per-iter", type=int, default=3) prun.set_defaults(func=cmd_run) - pint = sub.add_parser("intent", parents=[common], help="Set user intent for a run (stored in MCP catalog)") - pint.add_argument("--run-id", required=True) - pint.add_argument("--audience", default="mixed") - pint.add_argument("--goals", nargs="*", default=["qna"]) - pint.add_argument("--constraint", action="append", help="constraint as key=value; repeatable") - pint.set_defaults(func=cmd_intent) - return p - def main(): - parser = build_parser() - args = parser.parse_args() + args = build_parser().parse_args() try: asyncio.run(args.func(args)) except KeyboardInterrupt: - Console().print("\n[yellow]Interrupted[/yellow]") + Console().print("\n[yellow]Interrupted[/yellow]") raise SystemExit(130) - if __name__ == "__main__": main()