Enhance Rich CLI with configurable LLM chat path and better tracing

LLM improvements:
- Add configurable chat path (LLM_CHAT_PATH or --llm-chat-path) to support
  non-standard endpoints like Z.ai's /api/coding/paas/v4
- Add optional JSON mode (LLM_JSON_MODE or --llm-json-mode) for models
  that support native JSON output
- Enhanced tracing: log HTTP status and response body snippet on every request
- Safer JSON parsing: treat empty content as error with helpful message
- Better error messages with diagnostic hints

Code cleanup:
- Remove intent command (simplify CLI)
- Remove user_intent reading and passing
- Simplify stopping logic (just run max_iterations)
- Clean up formatting and remove unused code
pull/5310/head
Rene Cannao 4 months ago
parent f2ca750c05
commit 9d6a2173bf

@ -1,26 +1,34 @@
\
#!/usr/bin/env python3
"""
Database Discovery Agent (Async CLI, Rich UI)
Key fixes vs earlier version:
- MCP tools are invoked via JSON-RPC method **tools/call** (NOT by calling tool name as method).
- Supports HTTPS + Bearer token + optional insecure TLS (self-signed certs).
Environment variables (or CLI flags):
- MCP_ENDPOINT (e.g. https://127.0.0.1:6071/mcp/query)
- MCP_AUTH_TOKEN (Bearer token, if required)
- MCP_INSECURE_TLS=1 to disable TLS verification (like curl -k)
- LLM_BASE_URL (OpenAI-compatible base, e.g. https://api.openai.com)
- LLM_API_KEY
- LLM_MODEL
This version focuses on robustness + debuggability:
MCP:
- Calls tools via JSON-RPC method: tools/call
- Supports HTTPS + Bearer token + optional insecure TLS (self-signed) via:
- MCP_INSECURE_TLS=1 or --mcp-insecure-tls
LLM:
- OpenAI-compatible *or* OpenAI-like gateways with nonstandard base paths
- Configurable chat path (NO more hardcoded /v1):
- LLM_CHAT_PATH (default: /v1/chat/completions) or --llm-chat-path
- Stronger tracing:
- logs HTTP status + response text snippet on every LLM request
- Safer JSON parsing:
- treats empty content as an error
- optional response_format={"type":"json_object"} (enable with --llm-json-mode)
Environment variables:
- MCP_ENDPOINT, MCP_AUTH_TOKEN, MCP_INSECURE_TLS
- LLM_BASE_URL, LLM_API_KEY, LLM_MODEL, LLM_CHAT_PATH, LLM_INSECURE_TLS, LLM_JSON_MODE
"""
import argparse
import asyncio
import json
import os
import sys
import time
import uuid
import traceback
@ -37,7 +45,6 @@ from rich.table import Table
from rich.text import Text
from rich.layout import Layout
ExpertName = Literal["planner", "structural", "statistical", "semantic", "query"]
KNOWN_MCP_TOOLS = {
@ -55,7 +62,6 @@ ALLOWED_TOOLS: Dict[ExpertName, set] = {
"query": {"explain_sql", "run_sql_readonly", "catalog_search", "catalog_get", "catalog_list"},
}
class ToolCall(BaseModel):
name: str
args: Dict[str, Any] = Field(default_factory=dict)
@ -92,7 +98,6 @@ class ExpertReflect(BaseModel):
insights: List[Dict[str, Any]] = Field(default_factory=list)
questions_for_user: List[QuestionForUser] = Field(default_factory=list)
class TraceLogger:
def __init__(self, path: Optional[str]):
self.path = path
@ -105,7 +110,6 @@ class TraceLogger:
with open(self.path, "a", encoding="utf-8") as f:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
class MCPError(RuntimeError):
pass
@ -118,7 +122,7 @@ class MCPClient:
async def rpc(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
req_id = str(uuid.uuid4())
payload = {"jsonrpc": "2.0", "id": req_id, "method": method}
payload: Dict[str, Any] = {"jsonrpc": "2.0", "id": req_id, "method": method}
if params is not None:
payload["params"] = params
@ -144,7 +148,6 @@ class MCPClient:
result = await self.rpc("tools/call", {"name": tool_name, "arguments": args})
self.trace.write({"type": "mcp.result", "tool": tool_name, "result": result})
# Expected: {"success": true, "result": ...}
if isinstance(result, dict) and "success" in result:
if not result.get("success", False):
raise MCPError(f"MCP tool failed: {tool_name}: {result}")
@ -154,64 +157,117 @@ class MCPClient:
async def close(self):
await self.client.aclose()
class LLMError(RuntimeError):
pass
class LLMClient:
def __init__(self, base_url: str, api_key: str, model: str, trace: TraceLogger, insecure_tls: bool = False):
"""OpenAI-compatible chat client with configurable path and better tracing."""
def __init__(
self,
base_url: str,
api_key: str,
model: str,
trace: TraceLogger,
*,
insecure_tls: bool = False,
chat_path: str = "/v1/chat/completions",
json_mode: bool = False,
):
self.base_url = base_url.rstrip("/")
self.chat_path = "/" + chat_path.strip("/")
self.api_key = api_key
self.model = model
self.trace = trace
self.client = httpx.AsyncClient(timeout=120.0, verify=(not insecure_tls))
self.json_mode = json_mode
self.client = httpx.AsyncClient(timeout=180.0, verify=(not insecure_tls))
async def close(self):
await self.client.aclose()
async def chat_json(self, system: str, user: str, *, max_tokens: int = 1200) -> Dict[str, Any]:
url = f"{self.base_url}/chat/completions"
url = f"{self.base_url}{self.chat_path}"
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
payload: Dict[str, Any] = {
"model": self.model,
"temperature": 0.2,
"max_tokens": max_tokens,
"stream": False,
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
],
}
if self.json_mode:
payload["response_format"] = {"type": "json_object"}
self.trace.write({
"type": "llm.request",
"model": self.model,
"url": url,
"system": system[:4000],
"user": user[:8000],
"json_mode": self.json_mode,
})
self.trace.write({"type": "llm.request", "model": self.model, "system": system[:4000], "user": user[:8000]})
r = await self.client.post(url, json=payload, headers=headers)
body_snip = r.text[:2000] if r.text else ""
self.trace.write({"type": "llm.http", "status": r.status_code, "body_snip": body_snip})
if r.status_code != 200:
raise LLMError(f"LLM HTTP {r.status_code}: {r.text}")
data = r.json()
try:
data = r.json()
except Exception as e:
raise LLMError(f"LLM returned non-JSON HTTP body: {body_snip}") from e
try:
content = data["choices"][0]["message"]["content"]
except Exception:
raise LLMError(f"Unexpected LLM response: {data}")
self.trace.write({"type": "llm.unexpected_schema", "keys": list(data.keys())})
raise LLMError(f"Unexpected LLM response schema. Keys={list(data.keys())}. Body={body_snip}")
if content is None:
content = ""
self.trace.write({"type": "llm.raw", "content": content})
if not str(content).strip():
raise LLMError("LLM returned empty content (check LLM_CHAT_PATH, auth, or gateway compatibility).")
try:
return json.loads(content)
except Exception:
repair_payload = {
repair_payload: Dict[str, Any] = {
"model": self.model,
"temperature": 0.0,
"max_tokens": 1200,
"stream": False,
"messages": [
{"role": "system", "content": "Return ONLY valid JSON, no prose."},
{"role": "user", "content": f"Fix into valid JSON:\n\n{content}"},
],
}
self.trace.write({"type": "llm.repair.request", "bad": content[:8000]})
if self.json_mode:
repair_payload["response_format"] = {"type": "json_object"}
self.trace.write({"type": "llm.repair.request", "bad": str(content)[:8000]})
r2 = await self.client.post(url, json=repair_payload, headers=headers)
self.trace.write({"type": "llm.repair.http", "status": r2.status_code, "body_snip": (r2.text[:2000] if r2.text else "")})
if r2.status_code != 200:
raise LLMError(f"LLM repair HTTP {r2.status_code}: {r2.text}")
data2 = r2.json()
content2 = data2["choices"][0]["message"]["content"]
content2 = data2.get("choices", [{}])[0].get("message", {}).get("content", "")
self.trace.write({"type": "llm.repair.raw", "content": content2})
if not str(content2).strip():
raise LLMError("LLM repair returned empty content (gateway misconfig or unsupported endpoint).")
try:
return json.loads(content2)
except Exception as e:
@ -257,7 +313,6 @@ Rules:
- Ask at most ONE question per reflect step, only if it materially changes exploration.
"""
@dataclass
class UIState:
run_id: str
@ -274,7 +329,6 @@ class UIState:
if self.planned_tasks is None:
self.planned_tasks = []
def normalize_list(res: Any, keys: Tuple[str, ...]) -> List[Any]:
if isinstance(res, list):
return res
@ -290,16 +344,11 @@ def item_name(x: Any) -> str:
return str(x["name"])
return str(x)
def now_iso() -> str:
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
class Agent:
def __init__(self, mcp: MCPClient, llm: LLMClient, trace: TraceLogger, debug: bool):
def __init__(self, mcp: MCPClient, llm: LLMClient, trace: TraceLogger):
self.mcp = mcp
self.llm = llm
self.trace = trace
self.debug = debug
async def planner(self, schema: str, tables: List[str], user_intent: Optional[Dict[str, Any]]) -> List[PlannedTask]:
user = json.dumps({
@ -310,23 +359,15 @@ class Agent:
}, ensure_ascii=False)
raw = await self.llm.chat_json(PLANNER_SYSTEM, user, max_tokens=900)
try:
out = PlannerOut.model_validate(raw)
except ValidationError as e:
raise LLMError(f"Planner output invalid: {e}\nraw={raw}")
tasks = [t for t in out.tasks if t.expert in ("structural","statistical","semantic","query")]
out = PlannerOut.model_validate(raw)
tasks = [t for t in out.tasks if t.expert in ("structural", "statistical", "semantic", "query")]
tasks.sort(key=lambda t: t.priority, reverse=True)
return tasks[:6]
async def expert_act(self, expert: ExpertName, ctx: Dict[str, Any]) -> ExpertAct:
system = EXPERT_ACT_SYSTEM.format(expert=expert, allowed_tools=sorted(ALLOWED_TOOLS[expert]))
raw = await self.llm.chat_json(system, json.dumps(ctx, ensure_ascii=False), max_tokens=900)
try:
act = ExpertAct.model_validate(raw)
except ValidationError as e:
raise LLMError(f"{expert} ACT invalid: {e}\nraw={raw}")
act = ExpertAct.model_validate(raw)
act.tool_calls = act.tool_calls[:6]
for c in act.tool_calls:
if c.name not in KNOWN_MCP_TOOLS:
@ -340,27 +381,18 @@ class Agent:
user = dict(ctx)
user["tool_results"] = tool_results
raw = await self.llm.chat_json(system, json.dumps(user, ensure_ascii=False), max_tokens=1200)
try:
ref = ExpertReflect.model_validate(raw)
except ValidationError as e:
raise LLMError(f"{expert} REFLECT invalid: {e}\nraw={raw}")
return ref
return ExpertReflect.model_validate(raw)
async def apply_catalog_writes(self, writes: List[CatalogWrite]):
for w in writes:
await self.mcp.call_tool("catalog_upsert", {
"kind": w.kind,
"key": w.key,
"document": w.document,
"tags": w.tags,
"links": w.links
"kind": w.kind, "key": w.key, "document": w.document, "tags": w.tags, "links": w.links
})
async def run(self, ui: UIState, schema: Optional[str], max_iterations: int, tasks_per_iter: int):
ui.phase = "bootstrap"
schemas_res = await self.mcp.call_tool("list_schemas", {"page_size": 50})
schemas = schemas_res if isinstance(schemas_res, list) else normalize_list(schemas_res, ("schemas","items","result"))
schemas = schemas_res if isinstance(schemas_res, list) else normalize_list(schemas_res, ("schemas", "items", "result"))
if not schemas:
raise MCPError("No schemas returned by MCP list_schemas")
@ -368,52 +400,27 @@ class Agent:
ui.last_event = f"Selected schema: {chosen_schema}"
tables_res = await self.mcp.call_tool("list_tables", {"schema": chosen_schema, "page_size": 500})
tables = tables_res if isinstance(tables_res, list) else normalize_list(tables_res, ("tables","items","result"))
tables = tables_res if isinstance(tables_res, list) else normalize_list(tables_res, ("tables", "items", "result"))
table_names = [item_name(t) for t in tables]
if not table_names:
raise MCPError(f"No tables returned by MCP list_tables(schema={chosen_schema})")
user_intent = None
try:
ig = await self.mcp.call_tool("catalog_get", {"kind": "intent", "key": f"intent/{ui.run_id}"})
if isinstance(ig, dict) and ig.get("document"):
user_intent = json.loads(ig["document"])
except Exception:
user_intent = None
ui.phase = "running"
no_progress_streak = 0
for it in range(1, max_iterations + 1):
ui.iteration = it
ui.last_event = "Planning tasks…"
tasks = await self.planner(chosen_schema, table_names, user_intent)
tasks = await self.planner(chosen_schema, table_names, None)
ui.planned_tasks = tasks
ui.last_event = f"Planned {len(tasks)} tasks"
if not tasks:
ui.phase = "done"
ui.last_event = "No tasks from planner"
return
executed = 0
before_insights = ui.insights
before_writes = ui.catalog_writes
for task in tasks:
if executed >= tasks_per_iter:
break
executed += 1
expert = task.expert
ctx = {
"run_id": ui.run_id,
"schema": task.schema,
"table": task.table,
"goal": task.goal,
"user_intent": user_intent,
"note": "Choose minimal tool calls to advance discovery."
}
ctx = {"run_id": ui.run_id, "schema": task.schema, "table": task.table, "goal": task.goal}
ui.last_event = f"{expert} ACT: {task.goal}" + (f" ({task.table})" if task.table else "")
act = await self.expert_act(expert, ctx)
@ -427,49 +434,16 @@ class Agent:
ui.last_event = f"{expert} REFLECT"
ref = await self.expert_reflect(expert, ctx, tool_results)
if ref.catalog_writes:
await self.apply_catalog_writes(ref.catalog_writes)
ui.catalog_writes += len(ref.catalog_writes)
for q in ref.questions_for_user[:1]:
payload = {
"run_id": ui.run_id,
"question_id": q.question_id,
"title": q.title,
"prompt": q.prompt,
"options": q.options,
"created_at": now_iso()
}
await self.mcp.call_tool("catalog_upsert", {
"kind": "question",
"key": f"question/{ui.run_id}/{q.question_id}",
"document": json.dumps(payload, ensure_ascii=False),
"tags": f"run:{ui.run_id}"
})
ui.catalog_writes += 1
ui.insights += len(ref.insights)
gained_insights = ui.insights - before_insights
gained_writes = ui.catalog_writes - before_writes
if gained_insights == 0 and gained_writes == 0:
no_progress_streak += 1
else:
no_progress_streak = 0
if no_progress_streak >= 2:
ui.phase = "done"
ui.last_event = "Stopping: diminishing returns"
return
ui.phase = "done"
ui.last_event = "Finished: max_iterations reached"
ui.last_event = "Finished"
def render(ui: UIState) -> Layout:
layout = Layout()
header = Text()
header.append("Database Discovery Agent ", style="bold")
header.append(f"(run_id: {ui.run_id})", style="dim")
@ -488,25 +462,26 @@ def render(ui: UIState) -> Layout:
tasks_table.add_column("Expert", width=11)
tasks_table.add_column("Goal")
tasks_table.add_column("Table", style="dim")
for t in (ui.planned_tasks or [])[:10]:
tasks_table.add_row(f"{t.priority:.2f}", t.expert, t.goal, t.table or "")
events = Text()
if ui.last_event:
events.append(ui.last_event, style="white")
events.append(ui.last_event)
if ui.last_error:
events.append("\n")
events.append("\\n")
events.append(ui.last_error, style="bold red")
layout.split_column(
Layout(Panel(header, border_style="cyan"), size=3),
Layout(Panel(status, title="Status", border_style="green"), size=8),
Layout(Panel(tasks_table, border_style="magenta"), ratio=2),
Layout(Panel(events, title="Last event", border_style="yellow"), size=6),
Layout(Panel(events, title="Last event", border_style="yellow"), size=7),
)
return layout
def _truthy(s: str) -> bool:
return s in ("1", "true", "TRUE", "yes", "YES", "y", "Y")
async def cmd_run(args: argparse.Namespace):
console = Console()
@ -514,27 +489,30 @@ async def cmd_run(args: argparse.Namespace):
mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "")
mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN")
mcp_insecure = args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
mcp_insecure = args.mcp_insecure_tls or _truthy(os.getenv("MCP_INSECURE_TLS", "0"))
llm_base = args.llm_base_url or os.getenv("LLM_BASE_URL", "https://api.openai.com")
llm_key = args.llm_api_key or os.getenv("LLM_API_KEY", "")
llm_model = args.llm_model or os.getenv("LLM_MODEL", "gpt-4o-mini")
llm_insecure = args.llm_insecure_tls or (os.getenv("LLM_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
llm_chat_path = args.llm_chat_path or os.getenv("LLM_CHAT_PATH", "/v1/chat/completions")
llm_insecure = args.llm_insecure_tls or _truthy(os.getenv("LLM_INSECURE_TLS", "0"))
llm_json_mode = args.llm_json_mode or _truthy(os.getenv("LLM_JSON_MODE", "0"))
if not mcp_endpoint:
console.print("[bold red]MCP endpoint is required (set MCP_ENDPOINT or --mcp-endpoint)[/bold red]")
raise SystemExit(2)
if "openai.com" in llm_base and not llm_key:
console.print("[bold red]LLM_API_KEY is required for OpenAI[/bold red]")
console.print("[bold red]MCP_ENDPOINT missing (or --mcp-endpoint)[/bold red]")
raise SystemExit(2)
run_id = args.run_id or str(uuid.uuid4())
ui = UIState(run_id=run_id)
mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure)
llm = LLMClient(llm_base, llm_key, llm_model, trace, insecure_tls=llm_insecure)
agent = Agent(mcp, llm, trace, debug=args.debug)
llm = LLMClient(
llm_base, llm_key, llm_model, trace,
insecure_tls=llm_insecure,
chat_path=llm_chat_path,
json_mode=llm_json_mode,
)
agent = Agent(mcp, llm, trace)
async def runner():
try:
@ -546,100 +524,54 @@ async def cmd_run(args: argparse.Namespace):
if args.debug:
tb = traceback.format_exc()
trace.write({"type": "error.traceback", "traceback": tb})
ui.last_error += "\n" + tb
ui.last_error += "\\n" + tb
finally:
await mcp.close()
await llm.close()
task = asyncio.create_task(runner())
t = asyncio.create_task(runner())
with Live(render(ui), refresh_per_second=8, console=console):
while not task.done():
while not t.done():
await asyncio.sleep(0.1)
console.print(render(ui))
if ui.phase == "error":
raise SystemExit(1)
async def cmd_intent(args: argparse.Namespace):
console = Console()
trace = TraceLogger(args.trace)
mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "")
mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN")
mcp_insecure = args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
if not mcp_endpoint:
console.print("[bold red]MCP endpoint is required[/bold red]")
raise SystemExit(2)
payload = {
"run_id": args.run_id,
"audience": args.audience,
"goals": args.goals,
"constraints": {},
"updated_at": now_iso()
}
for kv in (args.constraint or []):
if "=" in kv:
k, v = kv.split("=", 1)
payload["constraints"][k] = v
mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure)
try:
await mcp.call_tool("catalog_upsert", {
"kind": "intent",
"key": f"intent/{args.run_id}",
"document": json.dumps(payload, ensure_ascii=False),
"tags": f"run:{args.run_id}"
})
console.print("[green]Intent stored[/green]")
finally:
await mcp.close()
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(prog="discover_cli", description="Database Discovery Agent (Async CLI)")
sub = p.add_subparsers(dest="cmd", required=True)
common = argparse.ArgumentParser(add_help=False)
common.add_argument("--mcp-endpoint", default=None, help="MCP JSON-RPC endpoint (or MCP_ENDPOINT env)")
common.add_argument("--mcp-auth-token", default=None, help="MCP auth token (or MCP_AUTH_TOKEN env)")
common.add_argument("--mcp-insecure-tls", action="store_true", help="Disable MCP TLS verification (like curl -k)")
common.add_argument("--llm-base-url", default=None, help="OpenAI-compatible base URL (or LLM_BASE_URL env)")
common.add_argument("--llm-api-key", default=None, help="LLM API key (or LLM_API_KEY env)")
common.add_argument("--llm-model", default=None, help="LLM model (or LLM_MODEL env)")
common.add_argument("--llm-insecure-tls", action="store_true", help="Disable LLM TLS verification")
common.add_argument("--trace", default=None, help="Write JSONL trace to this file")
common.add_argument("--debug", action="store_true", help="Show stack traces")
prun = sub.add_parser("run", parents=[common], help="Run discovery")
prun.add_argument("--run-id", default=None, help="Optional run id (uuid). If omitted, generated.")
prun.add_argument("--schema", default=None, help="Optional schema to focus on")
common.add_argument("--mcp-endpoint", default=None)
common.add_argument("--mcp-auth-token", default=None)
common.add_argument("--mcp-insecure-tls", action="store_true")
common.add_argument("--llm-base-url", default=None)
common.add_argument("--llm-api-key", default=None)
common.add_argument("--llm-model", default=None)
common.add_argument("--llm-chat-path", default=None, help="e.g. /v1/chat/completions or /v4/chat/completions")
common.add_argument("--llm-insecure-tls", action="store_true")
common.add_argument("--llm-json-mode", action="store_true")
common.add_argument("--trace", default=None)
common.add_argument("--debug", action="store_true")
prun = sub.add_parser("run", parents=[common])
prun.add_argument("--run-id", default=None)
prun.add_argument("--schema", default=None)
prun.add_argument("--max-iterations", type=int, default=6)
prun.add_argument("--tasks-per-iter", type=int, default=3)
prun.set_defaults(func=cmd_run)
pint = sub.add_parser("intent", parents=[common], help="Set user intent for a run (stored in MCP catalog)")
pint.add_argument("--run-id", required=True)
pint.add_argument("--audience", default="mixed")
pint.add_argument("--goals", nargs="*", default=["qna"])
pint.add_argument("--constraint", action="append", help="constraint as key=value; repeatable")
pint.set_defaults(func=cmd_intent)
return p
def main():
parser = build_parser()
args = parser.parse_args()
args = build_parser().parse_args()
try:
asyncio.run(args.func(args))
except KeyboardInterrupt:
Console().print("\n[yellow]Interrupted[/yellow]")
Console().print("\\n[yellow]Interrupted[/yellow]")
raise SystemExit(130)
if __name__ == "__main__":
main()

Loading…
Cancel
Save