mirror of https://github.com/sysown/proxysql
Add a database discovery agent prototype that uses LLMs to explore databases through the MCP Query endpoint.

Includes:
- Rich CLI (discover_cli.py): Working async CLI with Rich TUI, proper MCP tools/call JSON-RPC method, and full tracing support
- FastAPI_deprecated_POC: Early prototype with incorrect MCP protocol, kept for reference only

The Rich CLI version implements a multi-expert agent architecture:
- Planner: Chooses next tasks based on state
- Structural Expert: Analyzes table structure and relationships
- Statistical Expert: Profiles tables and columns
- Semantic Expert: Infers domain meaning
- Query Expert: Validates access patterns

pull/5310/head
parent 119ca5003a
commit f2ca750c05
@@ -0,0 +1,15 @@
# Python virtual environments
.venv/
venv/
__pycache__/
*.pyc
*.pyo

# Trace files (optional - comment out if you want to commit traces)
trace.jsonl
*.jsonl

# IDE
.vscode/
.idea/
*.swp
@@ -0,0 +1,18 @@
# DEPRECATED - Proof of Concept Only

This FastAPI implementation was an initial prototype and **is not working**.

The MCP protocol implementation here is incorrect - it attempts to call tool names directly as JSON-RPC methods instead of using the proper `tools/call` wrapper.
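
For illustration, a minimal sketch of the difference (hypothetical payloads; the tool name and arguments are just placeholders):

```python
# Hypothetical payloads, for illustration only.

# What this POC sends (incorrect): the tool name is used as the JSON-RPC method.
wrong = {
    "jsonrpc": "2.0",
    "id": "1",
    "method": "list_schemas",        # tool name used directly as the method
    "params": {"page_size": 50},
}

# What the endpoint expects (and what the Rich CLI sends): the generic
# "tools/call" method, with tool name and arguments wrapped inside params.
right = {
    "jsonrpc": "2.0",
    "id": "1",
    "method": "tools/call",
    "params": {"name": "list_schemas", "arguments": {"page_size": 50}},
}
```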

## Use the Rich CLI Instead

For a working implementation, use the **Rich CLI** version in the `../Rich/` directory:
- `Rich/discover_cli.py` - Working async CLI with Rich TUI
- Proper MCP `tools/call` JSON-RPC method
- Full tracing and debugging support

## Status

- Do NOT attempt to run this code
- Kept for reference/archival purposes only
- May be removed in future commits
@@ -0,0 +1,601 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Literal, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
# ============================================================
|
||||
# MCP client (JSON-RPC)
|
||||
# ============================================================
|
||||
|
||||
class MCPError(RuntimeError):
|
||||
pass
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self, endpoint: str, auth_token: Optional[str] = None, timeout_sec: float = 120.0):
|
||||
self.endpoint = endpoint
|
||||
self.auth_token = auth_token
|
||||
self._client = httpx.AsyncClient(timeout=timeout_sec)
|
||||
|
||||
async def call(self, method: str, params: Dict[str, Any]) -> Any:
|
||||
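        # NOTE: this sends the tool name directly as the JSON-RPC "method",
        # which is the incorrect MCP usage described in this POC's README
        # (the working Rich CLI wraps calls in "tools/call"). Kept as-is for reference.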
req_id = str(uuid.uuid4())
|
||||
payload = {"jsonrpc": "2.0", "id": req_id, "method": method, "params": params}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
r = await self._client.post(self.endpoint, json=payload, headers=headers)
|
||||
if r.status_code != 200:
|
||||
raise MCPError(f"MCP HTTP {r.status_code}: {r.text}")
|
||||
data = r.json()
|
||||
if "error" in data:
|
||||
raise MCPError(f"MCP error: {data['error']}")
|
||||
return data.get("result")
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OpenAI-compatible LLM client (works with OpenAI or local servers that mimic it)
|
||||
# ============================================================
|
||||
|
||||
class LLMError(RuntimeError):
|
||||
pass
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
Calls an OpenAI-compatible /v1/chat/completions endpoint.
|
||||
Configure via env:
|
||||
LLM_BASE_URL (default: https://api.openai.com)
|
||||
LLM_API_KEY
|
||||
LLM_MODEL (default: gpt-4o-mini) # change as needed
|
||||
For local llama.cpp or vLLM OpenAI-compatible server: set LLM_BASE_URL accordingly.
|
||||
"""
|
||||
def __init__(self, base_url: str, api_key: str, model: str, timeout_sec: float = 120.0):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._client = httpx.AsyncClient(timeout=timeout_sec)
|
||||
|
||||
async def chat_json(self, system: str, user: str, max_tokens: int = 1200) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"temperature": 0.2,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
}
|
||||
|
||||
r = await self._client.post(url, json=payload, headers=headers)
|
||||
if r.status_code != 200:
|
||||
raise LLMError(f"LLM HTTP {r.status_code}: {r.text}")
|
||||
|
||||
data = r.json()
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
except Exception:
|
||||
raise LLMError(f"Unexpected LLM response: {data}")
|
||||
|
||||
# Strict JSON-only contract
|
||||
try:
|
||||
return json.loads(content)
|
||||
except Exception:
|
||||
# one repair attempt
|
||||
repair_system = "You are a JSON repair tool. Return ONLY valid JSON, no prose."
|
||||
repair_user = f"Fix this into valid JSON only:\n\n{content}"
|
||||
r2 = await self._client.post(url, json={
|
||||
"model": self.model,
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 1200,
|
||||
"messages": [
|
||||
{"role":"system","content":repair_system},
|
||||
{"role":"user","content":repair_user},
|
||||
],
|
||||
}, headers=headers)
|
||||
if r2.status_code != 200:
|
||||
raise LLMError(f"LLM repair HTTP {r2.status_code}: {r2.text}")
|
||||
data2 = r2.json()
|
||||
content2 = data2["choices"][0]["message"]["content"]
|
||||
try:
|
||||
return json.loads(content2)
|
||||
except Exception as e:
|
||||
raise LLMError(f"LLM returned non-JSON (even after repair): {content2}") from e
|
||||
|
||||
async def close(self):
|
||||
await self._client.aclose()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Shared schemas (LLM contracts)
|
||||
# ============================================================
|
||||
|
||||
ExpertName = Literal["planner", "structural", "statistical", "semantic", "query"]
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
name: str
|
||||
args: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class CatalogWrite(BaseModel):
|
||||
kind: str
|
||||
key: str
|
||||
document: str
|
||||
tags: Optional[str] = None
|
||||
links: Optional[str] = None
|
||||
|
||||
class QuestionForUser(BaseModel):
|
||||
question_id: str
|
||||
title: str
|
||||
prompt: str
|
||||
options: Optional[List[str]] = None
|
||||
|
||||
class ExpertAct(BaseModel):
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
|
||||
class ExpertReflect(BaseModel):
|
||||
catalog_writes: List[CatalogWrite] = Field(default_factory=list)
|
||||
insights: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
questions_for_user: List[QuestionForUser] = Field(default_factory=list)
|
||||
|
||||
class PlannedTask(BaseModel):
|
||||
expert: ExpertName
|
||||
goal: str
|
||||
schema: str
|
||||
table: Optional[str] = None
|
||||
priority: float = 0.5
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Tool allow-lists per expert (from your MCP tools/list)
|
||||
# ============================================================
|
||||
|
||||
TOOLS = {
|
||||
"list_schemas","list_tables","describe_table","get_constraints",
|
||||
"table_profile","column_profile","sample_rows","sample_distinct",
|
||||
"run_sql_readonly","explain_sql","suggest_joins","find_reference_candidates",
|
||||
"catalog_upsert","catalog_get","catalog_search","catalog_list","catalog_merge","catalog_delete"
|
||||
}
|
||||
|
||||
ALLOWED_TOOLS: Dict[ExpertName, set] = {
|
||||
"planner": {"catalog_search","catalog_list","catalog_get"}, # planner reads state only
|
||||
"structural": {"describe_table","get_constraints","suggest_joins","find_reference_candidates","catalog_search","catalog_get","catalog_list"},
|
||||
"statistical": {"table_profile","column_profile","sample_rows","sample_distinct","catalog_search","catalog_get","catalog_list"},
|
||||
"semantic": {"sample_rows","catalog_search","catalog_get","catalog_list"},
|
||||
"query": {"explain_sql","run_sql_readonly","catalog_search","catalog_get","catalog_list"},
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
# Prompts
|
||||
# ============================================================
|
||||
|
||||
PLANNER_SYSTEM = """You are the Planner agent for a database discovery system.
|
||||
You plan a small set of next tasks for specialist experts. Output ONLY JSON.
|
||||
|
||||
Rules:
|
||||
- Produce 1 to 6 tasks maximum.
|
||||
- Prefer high value tasks: relationship mapping, profiling key tables, domain inference.
|
||||
- Use schema/table names provided.
|
||||
- If user intent exists in catalog, prioritize accordingly.
|
||||
- Each task must include: expert, goal, schema, table(optional), priority (0..1).
|
||||
|
||||
Output schema:
|
||||
{ "tasks": [ { "expert": "...", "goal":"...", "schema":"...", "table":"optional", "priority":0.0 } ] }
|
||||
"""
|
||||
|
||||
EXPERT_ACT_SYSTEM_TEMPLATE = """You are the {expert} expert agent in a database discovery system.
|
||||
You can request MCP tools by returning JSON.
|
||||
|
||||
Return ONLY JSON in this schema:
|
||||
{{
|
||||
"tool_calls": [{{"name":"tool_name","args":{{...}}}}, ...],
|
||||
"notes": "optional brief note"
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- Only call tools from this allowed set: {allowed_tools}
|
||||
- Keep tool calls minimal and targeted.
|
||||
- Prefer sampling/profiling to full scans.
|
||||
- If unsure, request small samples (sample_rows) and/or lightweight profiles.
|
||||
"""
|
||||
|
||||
EXPERT_REFLECT_SYSTEM_TEMPLATE = """You are the {expert} expert agent. You are given results of tool calls.
|
||||
Synthesize them into durable catalog entries and (optionally) questions for the user.
|
||||
|
||||
Return ONLY JSON in this schema:
|
||||
{{
|
||||
"catalog_writes": [{{"kind":"...","key":"...","document":"...","tags":"optional","links":"optional"}}, ...],
|
||||
"insights": [{{"claim":"...","confidence":0.0,"evidence":[...]}}, ...],
|
||||
"questions_for_user": [{{"question_id":"...","title":"...","prompt":"...","options":["..."]}}, ...]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- catalog_writes.document MUST be a JSON string (i.e., json.dumps payload).
|
||||
- Use stable keys so entries can be updated: e.g. table/<schema>.<table>, col/<schema>.<table>.<col>, hypothesis/<id>, intent/<run_id>
|
||||
- If you detect ambiguity about goal/audience, ask ONE focused question.
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Expert implementations
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class ExpertContext:
|
||||
run_id: str
|
||||
schema: str
|
||||
table: Optional[str]
|
||||
user_intent: Optional[Dict[str, Any]]
|
||||
catalog_snippets: List[Dict[str, Any]]
|
||||
|
||||
class Expert:
|
||||
def __init__(self, name: ExpertName, llm: LLMClient, mcp: MCPClient, emit):
|
||||
self.name = name
|
||||
self.llm = llm
|
||||
self.mcp = mcp
|
||||
self.emit = emit
|
||||
|
||||
async def act(self, ctx: ExpertContext) -> ExpertAct:
|
||||
system = EXPERT_ACT_SYSTEM_TEMPLATE.format(
|
||||
expert=self.name,
|
||||
allowed_tools=sorted(ALLOWED_TOOLS[self.name])
|
||||
)
|
||||
user = {
|
||||
"run_id": ctx.run_id,
|
||||
"schema": ctx.schema,
|
||||
"table": ctx.table,
|
||||
"user_intent": ctx.user_intent,
|
||||
"catalog_snippets": ctx.catalog_snippets[:10],
|
||||
"request": f"Choose the best MCP tool calls for your expert role ({self.name}) to advance discovery."
|
||||
}
|
||||
raw = await self.llm.chat_json(system, json.dumps(user, ensure_ascii=False), max_tokens=900)
|
||||
try:
|
||||
return ExpertAct.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"{self.name} act schema invalid: {e}\nraw={raw}")
|
||||
|
||||
async def reflect(self, ctx: ExpertContext, tool_results: List[Dict[str, Any]]) -> ExpertReflect:
|
||||
system = EXPERT_REFLECT_SYSTEM_TEMPLATE.format(expert=self.name)
|
||||
user = {
|
||||
"run_id": ctx.run_id,
|
||||
"schema": ctx.schema,
|
||||
"table": ctx.table,
|
||||
"user_intent": ctx.user_intent,
|
||||
"catalog_snippets": ctx.catalog_snippets[:10],
|
||||
"tool_results": tool_results,
|
||||
"instruction": "Write catalog entries that capture durable discoveries."
|
||||
}
|
||||
raw = await self.llm.chat_json(system, json.dumps(user, ensure_ascii=False), max_tokens=1200)
|
||||
try:
|
||||
return ExpertReflect.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"{self.name} reflect schema invalid: {e}\nraw={raw}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Orchestrator
|
||||
# ============================================================
|
||||
|
||||
class Orchestrator:
|
||||
def __init__(self, run_id: str, mcp: MCPClient, llm: LLMClient, emit):
|
||||
self.run_id = run_id
|
||||
self.mcp = mcp
|
||||
self.llm = llm
|
||||
self.emit = emit
|
||||
|
||||
self.experts: Dict[ExpertName, Expert] = {
|
||||
"structural": Expert("structural", llm, mcp, emit),
|
||||
"statistical": Expert("statistical", llm, mcp, emit),
|
||||
"semantic": Expert("semantic", llm, mcp, emit),
|
||||
"query": Expert("query", llm, mcp, emit),
|
||||
"planner": Expert("planner", llm, mcp, emit), # not used as Expert; planner has special prompt
|
||||
}
|
||||
|
||||
async def _catalog_search(self, query: str, kind: Optional[str] = None, tags: Optional[str] = None, limit: int = 10):
|
||||
params = {"query": query, "limit": limit, "offset": 0}
|
||||
if kind:
|
||||
params["kind"] = kind
|
||||
if tags:
|
||||
params["tags"] = tags
|
||||
return await self.mcp.call("catalog_search", params)
|
||||
|
||||
async def _get_user_intent(self) -> Optional[Dict[str, Any]]:
|
||||
# Convention: kind="intent", key="intent/<run_id>"
|
||||
try:
|
||||
res = await self.mcp.call("catalog_get", {"kind": "intent", "key": f"intent/{self.run_id}"})
|
||||
if not res:
|
||||
return None
|
||||
doc = res.get("document")
|
||||
if not doc:
|
||||
return None
|
||||
return json.loads(doc)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _upsert_question(self, q: QuestionForUser):
|
||||
payload = {
|
||||
"run_id": self.run_id,
|
||||
"question_id": q.question_id,
|
||||
"title": q.title,
|
||||
"prompt": q.prompt,
|
||||
"options": q.options,
|
||||
"created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
}
|
||||
await self.mcp.call("catalog_upsert", {
|
||||
"kind": "question",
|
||||
"key": f"question/{self.run_id}/{q.question_id}",
|
||||
"document": json.dumps(payload, ensure_ascii=False),
|
||||
"tags": f"run:{self.run_id}"
|
||||
})
|
||||
|
||||
async def _execute_tool_calls(self, expert: ExpertName, calls: List[ToolCall]) -> List[Dict[str, Any]]:
|
||||
results = []
|
||||
for c in calls:
|
||||
if c.name not in TOOLS:
|
||||
raise MCPError(f"Unknown tool: {c.name}")
|
||||
if c.name not in ALLOWED_TOOLS[expert]:
|
||||
raise MCPError(f"Tool not allowed for {expert}: {c.name}")
|
||||
await self.emit("tool", "call", {"expert": expert, "name": c.name, "args": c.args})
|
||||
res = await self.mcp.call(c.name, c.args)
|
||||
results.append({"tool": c.name, "args": c.args, "result": res})
|
||||
return results
|
||||
|
||||
async def _apply_catalog_writes(self, expert: ExpertName, writes: List[CatalogWrite]):
|
||||
for w in writes:
|
||||
await self.emit("catalog", "upsert", {"expert": expert, "kind": w.kind, "key": w.key})
|
||||
await self.mcp.call("catalog_upsert", {
|
||||
"kind": w.kind,
|
||||
"key": w.key,
|
||||
"document": w.document,
|
||||
"tags": w.tags or f"run:{self.run_id},expert:{expert}",
|
||||
"links": w.links,
|
||||
})
|
||||
|
||||
async def _planner(self, schema: str, tables: List[str], user_intent: Optional[Dict[str, Any]]) -> List[PlannedTask]:
|
||||
# Pull a small slice of catalog state to inform planning
|
||||
snippets = []
|
||||
try:
|
||||
sres = await self._catalog_search(query=f"run:{self.run_id}", limit=10)
|
||||
items = sres.get("items") or sres.get("results") or []
|
||||
snippets = items[:10]
|
||||
except Exception:
|
||||
snippets = []
|
||||
|
||||
user = {
|
||||
"run_id": self.run_id,
|
||||
"schema": schema,
|
||||
"tables": tables[:200],
|
||||
"user_intent": user_intent,
|
||||
"catalog_snippets": snippets,
|
||||
"instruction": "Plan next tasks."
|
||||
}
|
||||
raw = await self.llm.chat_json(PLANNER_SYSTEM, json.dumps(user, ensure_ascii=False), max_tokens=900)
|
||||
try:
|
||||
tasks_raw = raw.get("tasks", [])
|
||||
tasks = [PlannedTask.model_validate(t) for t in tasks_raw]
|
||||
# enforce allowed experts
|
||||
tasks = [t for t in tasks if t.expert in ("structural","statistical","semantic","query")]
|
||||
tasks.sort(key=lambda x: x.priority, reverse=True)
|
||||
return tasks[:6]
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"Planner schema invalid: {e}\nraw={raw}")
|
||||
|
||||
async def run(self, schema: Optional[str], max_iterations: int, tasks_per_iter: int):
|
||||
await self.emit("run", "starting", {"run_id": self.run_id})
|
||||
|
||||
schemas_res = await self.mcp.call("list_schemas", {"page_size": 50})
|
||||
schemas = schemas_res.get("schemas") or schemas_res.get("items") or schemas_res.get("result") or []
|
||||
if not schemas:
|
||||
raise MCPError("No schemas returned by list_schemas")
|
||||
|
||||
chosen_schema = schema or (schemas[0]["name"] if isinstance(schemas[0], dict) else schemas[0])
|
||||
await self.emit("run", "schema_selected", {"schema": chosen_schema})
|
||||
|
||||
tables_res = await self.mcp.call("list_tables", {"schema": chosen_schema, "page_size": 500})
|
||||
tables = tables_res.get("tables") or tables_res.get("items") or tables_res.get("result") or []
|
||||
table_names = [(t["name"] if isinstance(t, dict) else t) for t in tables]
|
||||
if not table_names:
|
||||
raise MCPError(f"No tables returned by list_tables(schema={chosen_schema})")
|
||||
|
||||
await self.emit("run", "tables_listed", {"count": len(table_names)})
|
||||
|
||||
# Track simple diminishing returns
|
||||
last_insight_hashes: List[str] = []
|
||||
|
||||
for it in range(1, max_iterations + 1):
|
||||
user_intent = await self._get_user_intent()
|
||||
|
||||
tasks = await self._planner(chosen_schema, table_names, user_intent)
|
||||
await self.emit("run", "tasks_planned", {"iteration": it, "tasks": [t.model_dump() for t in tasks]})
|
||||
|
||||
if not tasks:
|
||||
await self.emit("run", "finished", {"run_id": self.run_id, "reason": "planner returned no tasks"})
|
||||
return
|
||||
|
||||
# Execute a bounded number per iteration
|
||||
executed = 0
|
||||
new_insights = 0
|
||||
|
||||
for task in tasks:
|
||||
if executed >= tasks_per_iter:
|
||||
break
|
||||
executed += 1
|
||||
|
||||
expert_name: ExpertName = task.expert
|
||||
expert = self.experts[expert_name]
|
||||
|
||||
# Collect small relevant context from catalog
|
||||
cat_snips = []
|
||||
try:
|
||||
# Pull table-specific snippets if possible
|
||||
q = task.table or ""
|
||||
sres = await self._catalog_search(query=q, limit=10)
|
||||
cat_snips = (sres.get("items") or sres.get("results") or [])[:10]
|
||||
except Exception:
|
||||
cat_snips = []
|
||||
|
||||
ctx = ExpertContext(
|
||||
run_id=self.run_id,
|
||||
schema=task.schema,
|
||||
table=task.table,
|
||||
user_intent=user_intent,
|
||||
catalog_snippets=cat_snips,
|
||||
)
|
||||
|
||||
await self.emit("run", "task_start", {"iteration": it, "task": task.model_dump()})
|
||||
|
||||
# 1) Expert ACT: request tools
|
||||
act = await expert.act(ctx)
|
||||
tool_results = await self._execute_tool_calls(expert_name, act.tool_calls)
|
||||
|
||||
# 2) Expert REFLECT: write catalog entries
|
||||
ref = await expert.reflect(ctx, tool_results)
|
||||
await self._apply_catalog_writes(expert_name, ref.catalog_writes)
|
||||
|
||||
# store questions (if any)
|
||||
for q in ref.questions_for_user:
|
||||
await self._upsert_question(q)
|
||||
|
||||
# crude diminishing return tracking via insight hashes
|
||||
for ins in ref.insights:
|
||||
h = json.dumps(ins, sort_keys=True)
|
||||
if h not in last_insight_hashes:
|
||||
last_insight_hashes.append(h)
|
||||
new_insights += 1
|
||||
last_insight_hashes = last_insight_hashes[-50:]
|
||||
|
||||
await self.emit("run", "task_done", {"iteration": it, "expert": expert_name, "new_insights": new_insights})
|
||||
|
||||
await self.emit("run", "iteration_done", {"iteration": it, "executed": executed, "new_insights": new_insights})
|
||||
|
||||
# Simple stop: if 2 iterations in a row produced no new insights
|
||||
if it >= 2 and new_insights == 0:
|
||||
await self.emit("run", "finished", {"run_id": self.run_id, "reason": "diminishing returns"})
|
||||
return
|
||||
|
||||
await self.emit("run", "finished", {"run_id": self.run_id, "reason": "max_iterations reached"})
|
||||
|
||||
|
||||
# ============================================================
|
||||
# FastAPI + SSE
|
||||
# ============================================================
|
||||
|
||||
app = FastAPI(title="Database Discovery Agent (LLM + Multi-Expert)")
|
||||
|
||||
RUNS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
class RunCreate(BaseModel):
|
||||
schema: Optional[str] = None
|
||||
max_iterations: int = 8
|
||||
tasks_per_iter: int = 3
|
||||
|
||||
def sse_format(event: Dict[str, Any]) -> str:
|
||||
return f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def event_emitter(q: asyncio.Queue) -> AsyncGenerator[bytes, None]:
|
||||
while True:
|
||||
ev = await q.get()
|
||||
yield sse_format(ev).encode("utf-8")
|
||||
if ev.get("type") == "run" and ev.get("message") in ("finished", "error"):
|
||||
return
|
||||
|
||||
@app.post("/runs")
|
||||
async def create_run(req: RunCreate):
|
||||
# LLM env
|
||||
llm_base = os.getenv("LLM_BASE_URL", "https://api.openai.com")
|
||||
llm_key = os.getenv("LLM_API_KEY", "")
|
||||
llm_model = os.getenv("LLM_MODEL", "gpt-4o-mini")
|
||||
|
||||
if not llm_key and "openai.com" in llm_base:
|
||||
raise HTTPException(status_code=400, detail="Set LLM_API_KEY (or use a local OpenAI-compatible server).")
|
||||
|
||||
# MCP env
|
||||
mcp_endpoint = os.getenv("MCP_ENDPOINT", "http://localhost:6071/mcp/query")
|
||||
mcp_token = os.getenv("MCP_AUTH_TOKEN")
|
||||
|
||||
run_id = str(uuid.uuid4())
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def emit(ev_type: str, message: str, data: Optional[Dict[str, Any]] = None):
|
||||
await q.put({
|
||||
"ts": time.time(),
|
||||
"run_id": run_id,
|
||||
"type": ev_type,
|
||||
"message": message,
|
||||
"data": data or {}
|
||||
})
|
||||
|
||||
mcp = MCPClient(mcp_endpoint, auth_token=mcp_token)
|
||||
llm = LLMClient(llm_base, llm_key, llm_model)
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
orch = Orchestrator(run_id, mcp, llm, emit)
|
||||
await orch.run(schema=req.schema, max_iterations=req.max_iterations, tasks_per_iter=req.tasks_per_iter)
|
||||
except Exception as e:
|
||||
await emit("run", "error", {"error": str(e)})
|
||||
finally:
|
||||
await mcp.close()
|
||||
await llm.close()
|
||||
|
||||
task = asyncio.create_task(runner())
|
||||
RUNS[run_id] = {"queue": q, "task": task}
|
||||
return {"run_id": run_id}
|
||||
|
||||
@app.get("/runs/{run_id}/events")
|
||||
async def stream_events(run_id: str):
|
||||
run = RUNS.get(run_id)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="run_id not found")
|
||||
return StreamingResponse(event_emitter(run["queue"]), media_type="text/event-stream")
|
||||
|
||||
class IntentUpsert(BaseModel):
|
||||
audience: Optional[str] = None # "dev"|"support"|"analytics"|"end_user"|...
|
||||
goals: Optional[List[str]] = None # e.g. ["qna","documentation","analytics"]
|
||||
constraints: Optional[Dict[str, Any]] = None
|
||||
|
||||
@app.post("/runs/{run_id}/intent")
|
||||
async def upsert_intent(run_id: str, intent: IntentUpsert):
|
||||
# Writes to MCP catalog so experts can read it immediately
|
||||
mcp_endpoint = os.getenv("MCP_ENDPOINT", "http://localhost:6071/mcp/query")
|
||||
mcp_token = os.getenv("MCP_AUTH_TOKEN")
|
||||
mcp = MCPClient(mcp_endpoint, auth_token=mcp_token)
|
||||
try:
|
||||
payload = intent.model_dump(exclude_none=True)
|
||||
payload["run_id"] = run_id
|
||||
payload["updated_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
await mcp.call("catalog_upsert", {
|
||||
"kind": "intent",
|
||||
"key": f"intent/{run_id}",
|
||||
"document": json.dumps(payload, ensure_ascii=False),
|
||||
"tags": f"run:{run_id}"
|
||||
})
|
||||
return {"ok": True}
|
||||
finally:
|
||||
await mcp.close()
|
||||
|
||||
@app.get("/runs/{run_id}/questions")
|
||||
async def list_questions(run_id: str):
|
||||
mcp_endpoint = os.getenv("MCP_ENDPOINT", "http://localhost:6071/mcp/query")
|
||||
mcp_token = os.getenv("MCP_AUTH_TOKEN")
|
||||
mcp = MCPClient(mcp_endpoint, auth_token=mcp_token)
|
||||
try:
|
||||
res = await mcp.call("catalog_search", {"query": f"question/{run_id}/", "limit": 50, "offset": 0})
|
||||
return res
|
||||
finally:
|
||||
await mcp.close()
|
||||
|
||||
@@ -0,0 +1,5 @@
fastapi==0.115.0
uvicorn[standard]==0.30.6
httpx==0.27.0
pydantic==2.8.2
python-dotenv==1.0.1
@@ -0,0 +1,68 @@
# TODO — Future Enhancements

This prototype prioritizes **runnability and debuggability**. Suggested next steps:

---

## 1) Catalog consistency

- Standardize catalog document structure (envelope with provenance + confidence); a possible shape is sketched below
- Enforce key naming conventions (structure/table, stats/col, semantic/entity, report, …)
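
A minimal sketch of what such an envelope might look like (every field name here is a proposal, not an existing format):

```python
# Hypothetical catalog document envelope; illustrative only.
entry = {
    "kind": "structure",
    "key": "structure/table/<schema>.<table>",          # stable key per the conventions above
    "payload": {"columns": [...], "constraints": [...]}, # expert-specific content
    "provenance": {
        "run_id": "<run-id>",
        "expert": "structural",                          # which expert produced it
        "tools": ["describe_table", "get_constraints"],  # MCP tools that supplied the evidence
    },
    "confidence": 0.8,  # 0..1, later usable for stopping/checkpoint decisions
}
```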

---

## 2) Better expert strategies

- Structural: relationship graph (constraints + join candidates)
- Statistical: prioritize high-signal columns; sampling-first for big tables
- Semantic: evidence-based claims, fewer hallucinations, ask user only when needed
- Query: safe mode (`explain_sql` by default; strict LIMIT for readonly SQL)

---

## 3) Coverage and confidence

- Track coverage: tables discovered vs analyzed vs profiled
- Compute confidence heuristics and use them for stopping/checkpoints

---

## 4) Planning improvements

- Task de-duplication (avoid repeating the same work)
- Heuristics for table prioritization if planner struggles early

---

## 5) Add commands

- `report --run-id <id>`: synthesize a readable report from catalog
- `replay --trace trace.jsonl`: iterate prompts without hitting the DB

---

## 6) Optional UI upgrade

Move from Rich Live to **Textual** for:
- scrolling logs
- interactive question answering
- better filtering and navigation

---

## 7) Controlled concurrency

Once stable (see the sketch after this list):
- run tasks concurrently with a semaphore
- per-table locks to avoid duplication
- keep catalog writes atomic per key
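
A minimal asyncio sketch of that shape (illustrative names, not existing code):

```python
import asyncio
from collections import defaultdict

# Illustrative sketch: cap overall concurrency with a semaphore and serialize
# work on the same table with per-table locks, so experts don't duplicate effort.
sem = asyncio.Semaphore(4)                # at most 4 expert tasks in flight
table_locks = defaultdict(asyncio.Lock)   # one lock per table name

async def run_task(task, execute):
    async with sem:                                     # global concurrency cap
        async with table_locks[task.table or "_schema"]:  # per-table mutual exclusion
            await execute(task)  # ACT -> MCP tools -> REFLECT -> catalog writes

async def run_iteration(tasks, execute):
    await asyncio.gather(*(run_task(t, execute) for t in tasks))
```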

---

## 8) MCP enhancements (later)

After real usage:
- batch table describes / batch column profiles
- explicit row-count estimation tool
- typed catalog documents (native JSON instead of string)
@@ -0,0 +1,645 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database Discovery Agent (Async CLI, Rich UI)
|
||||
|
||||
Key fixes vs earlier version:
|
||||
- MCP tools are invoked via JSON-RPC method **tools/call** (NOT by calling tool name as method).
|
||||
- Supports HTTPS + Bearer token + optional insecure TLS (self-signed certs).
|
||||
|
||||
Environment variables (or CLI flags):
|
||||
- MCP_ENDPOINT (e.g. https://127.0.0.1:6071/mcp/query)
|
||||
- MCP_AUTH_TOKEN (Bearer token, if required)
|
||||
- MCP_INSECURE_TLS=1 to disable TLS verification (like curl -k)
|
||||
|
||||
- LLM_BASE_URL (OpenAI-compatible base, e.g. https://api.openai.com)
|
||||
- LLM_API_KEY
|
||||
- LLM_MODEL
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Literal, Tuple
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
from rich.layout import Layout
|
||||
|
||||
|
||||
ExpertName = Literal["planner", "structural", "statistical", "semantic", "query"]
|
||||
|
||||
KNOWN_MCP_TOOLS = {
|
||||
"list_schemas", "list_tables", "describe_table", "get_constraints",
|
||||
"table_profile", "column_profile", "sample_rows", "sample_distinct",
|
||||
"run_sql_readonly", "explain_sql", "suggest_joins", "find_reference_candidates",
|
||||
"catalog_upsert", "catalog_get", "catalog_search", "catalog_list", "catalog_merge", "catalog_delete"
|
||||
}
|
||||
|
||||
ALLOWED_TOOLS: Dict[ExpertName, set] = {
|
||||
"planner": {"catalog_search", "catalog_list", "catalog_get"},
|
||||
"structural": {"describe_table", "get_constraints", "suggest_joins", "find_reference_candidates", "catalog_search", "catalog_get", "catalog_list"},
|
||||
"statistical": {"table_profile", "column_profile", "sample_rows", "sample_distinct", "catalog_search", "catalog_get", "catalog_list"},
|
||||
"semantic": {"sample_rows", "catalog_search", "catalog_get", "catalog_list"},
|
||||
"query": {"explain_sql", "run_sql_readonly", "catalog_search", "catalog_get", "catalog_list"},
|
||||
}
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
name: str
|
||||
args: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class PlannedTask(BaseModel):
|
||||
expert: ExpertName
|
||||
goal: str
|
||||
schema: str
|
||||
table: Optional[str] = None
|
||||
priority: float = 0.5
|
||||
|
||||
class PlannerOut(BaseModel):
|
||||
tasks: List[PlannedTask] = Field(default_factory=list)
|
||||
|
||||
class ExpertAct(BaseModel):
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
notes: Optional[str] = None
|
||||
|
||||
class CatalogWrite(BaseModel):
|
||||
kind: str
|
||||
key: str
|
||||
document: str
|
||||
tags: Optional[str] = None
|
||||
links: Optional[str] = None
|
||||
|
||||
class QuestionForUser(BaseModel):
|
||||
question_id: str
|
||||
title: str
|
||||
prompt: str
|
||||
options: Optional[List[str]] = None
|
||||
|
||||
class ExpertReflect(BaseModel):
|
||||
catalog_writes: List[CatalogWrite] = Field(default_factory=list)
|
||||
insights: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
questions_for_user: List[QuestionForUser] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TraceLogger:
|
||||
def __init__(self, path: Optional[str]):
|
||||
self.path = path
|
||||
|
||||
def write(self, record: Dict[str, Any]):
|
||||
if not self.path:
|
||||
return
|
||||
rec = dict(record)
|
||||
rec["ts"] = time.time()
|
||||
with open(self.path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
class MCPError(RuntimeError):
|
||||
pass
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self, endpoint: str, auth_token: Optional[str], trace: TraceLogger, insecure_tls: bool = False):
|
||||
self.endpoint = endpoint
|
||||
self.auth_token = auth_token
|
||||
self.trace = trace
|
||||
self.client = httpx.AsyncClient(timeout=120.0, verify=(not insecure_tls))
|
||||
|
||||
async def rpc(self, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
req_id = str(uuid.uuid4())
|
||||
payload = {"jsonrpc": "2.0", "id": req_id, "method": method}
|
||||
if params is not None:
|
||||
payload["params"] = params
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
|
||||
self.trace.write({"type": "mcp.rpc", "method": method, "params": params})
|
||||
r = await self.client.post(self.endpoint, json=payload, headers=headers)
|
||||
if r.status_code != 200:
|
||||
raise MCPError(f"MCP HTTP {r.status_code}: {r.text}")
|
||||
data = r.json()
|
||||
if "error" in data:
|
||||
raise MCPError(f"MCP error: {data['error']}")
|
||||
return data.get("result")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> Any:
|
||||
if tool_name not in KNOWN_MCP_TOOLS:
|
||||
raise MCPError(f"Unknown tool: {tool_name}")
|
||||
args = arguments or {}
|
||||
self.trace.write({"type": "mcp.call", "tool": tool_name, "arguments": args})
|
||||
|
||||
result = await self.rpc("tools/call", {"name": tool_name, "arguments": args})
|
||||
self.trace.write({"type": "mcp.result", "tool": tool_name, "result": result})
|
||||
|
||||
# Expected: {"success": true, "result": ...}
|
||||
if isinstance(result, dict) and "success" in result:
|
||||
if not result.get("success", False):
|
||||
raise MCPError(f"MCP tool failed: {tool_name}: {result}")
|
||||
return result.get("result")
|
||||
return result
|
||||
|
||||
async def close(self):
|
||||
await self.client.aclose()
|
||||
|
||||
|
||||
class LLMError(RuntimeError):
|
||||
pass
|
||||
|
||||
class LLMClient:
|
||||
def __init__(self, base_url: str, api_key: str, model: str, trace: TraceLogger, insecure_tls: bool = False):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.trace = trace
|
||||
self.client = httpx.AsyncClient(timeout=120.0, verify=(not insecure_tls))
|
||||
|
||||
async def chat_json(self, system: str, user: str, *, max_tokens: int = 1200) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"temperature": 0.2,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
}
|
||||
|
||||
self.trace.write({"type": "llm.request", "model": self.model, "system": system[:4000], "user": user[:8000]})
|
||||
r = await self.client.post(url, json=payload, headers=headers)
|
||||
if r.status_code != 200:
|
||||
raise LLMError(f"LLM HTTP {r.status_code}: {r.text}")
|
||||
data = r.json()
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
except Exception:
|
||||
raise LLMError(f"Unexpected LLM response: {data}")
|
||||
self.trace.write({"type": "llm.raw", "content": content})
|
||||
|
||||
try:
|
||||
return json.loads(content)
|
||||
except Exception:
|
||||
repair_payload = {
|
||||
"model": self.model,
|
||||
"temperature": 0.0,
|
||||
"max_tokens": 1200,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Return ONLY valid JSON, no prose."},
|
||||
{"role": "user", "content": f"Fix into valid JSON:\n\n{content}"},
|
||||
],
|
||||
}
|
||||
self.trace.write({"type": "llm.repair.request", "bad": content[:8000]})
|
||||
r2 = await self.client.post(url, json=repair_payload, headers=headers)
|
||||
if r2.status_code != 200:
|
||||
raise LLMError(f"LLM repair HTTP {r2.status_code}: {r2.text}")
|
||||
data2 = r2.json()
|
||||
content2 = data2["choices"][0]["message"]["content"]
|
||||
self.trace.write({"type": "llm.repair.raw", "content": content2})
|
||||
try:
|
||||
return json.loads(content2)
|
||||
except Exception as e:
|
||||
raise LLMError(f"LLM returned non-JSON after repair: {content2}") from e
|
||||
|
||||
|
||||
PLANNER_SYSTEM = """You are the Planner agent for a database discovery system.
|
||||
You plan a small set of next tasks for specialist experts. Output ONLY JSON.
|
||||
|
||||
Rules:
|
||||
- Produce 1 to 6 tasks maximum.
|
||||
- Prefer high-value tasks: mapping structure, finding relationships, profiling key tables, domain inference.
|
||||
- Consider user intent if provided.
|
||||
- Each task must include: expert, goal, schema, table(optional), priority (0..1).
|
||||
|
||||
Output schema:
|
||||
{"tasks":[{"expert":"structural|statistical|semantic|query","goal":"...","schema":"...","table":"optional","priority":0.0}]}
|
||||
"""
|
||||
|
||||
EXPERT_ACT_SYSTEM = """You are the {expert} expert agent.
|
||||
Return ONLY JSON in this schema:
|
||||
{{"tool_calls":[{{"name":"tool_name","args":{{...}}}}], "notes":"optional"}}
|
||||
|
||||
Rules:
|
||||
- Only call tools from: {allowed_tools}
|
||||
- Keep tool calls minimal (max 6).
|
||||
- Prefer sampling/profiling to full scans.
|
||||
- If unsure: sample_rows + lightweight profile first.
|
||||
"""
|
||||
|
||||
EXPERT_REFLECT_SYSTEM = """You are the {expert} expert agent. You are given results of tool calls.
|
||||
Synthesize durable catalog entries and (optionally) questions for the user.
|
||||
|
||||
Return ONLY JSON in this schema:
|
||||
{{
|
||||
"catalog_writes":[{{"kind":"...","key":"...","document":"JSON_STRING","tags":"optional","links":"optional"}}],
|
||||
"insights":[{{"claim":"...","confidence":0.0,"evidence":[...]}}],
|
||||
"questions_for_user":[{{"question_id":"...","title":"...","prompt":"...","options":["..."]}}]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- catalog_writes.document MUST be a JSON string (i.e. json.dumps of your payload).
|
||||
- Ask at most ONE question per reflect step, only if it materially changes exploration.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class UIState:
|
||||
run_id: str
|
||||
phase: str = "init"
|
||||
iteration: int = 0
|
||||
planned_tasks: List[PlannedTask] = None
|
||||
last_event: str = ""
|
||||
last_error: str = ""
|
||||
tool_calls: int = 0
|
||||
catalog_writes: int = 0
|
||||
insights: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.planned_tasks is None:
|
||||
self.planned_tasks = []
|
||||
|
||||
|
||||
def normalize_list(res: Any, keys: Tuple[str, ...]) -> List[Any]:
|
||||
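    # Tolerate different MCP response envelopes: a bare list, or a dict keyed by
    # one of the given names (e.g. "tables", "items", "result").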
if isinstance(res, list):
|
||||
return res
|
||||
if isinstance(res, dict):
|
||||
for k in keys:
|
||||
v = res.get(k)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
return []
|
||||
|
||||
def item_name(x: Any) -> str:
|
||||
if isinstance(x, dict) and "name" in x:
|
||||
return str(x["name"])
|
||||
return str(x)
|
||||
|
||||
def now_iso() -> str:
|
||||
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, mcp: MCPClient, llm: LLMClient, trace: TraceLogger, debug: bool):
|
||||
self.mcp = mcp
|
||||
self.llm = llm
|
||||
self.trace = trace
|
||||
self.debug = debug
|
||||
|
||||
async def planner(self, schema: str, tables: List[str], user_intent: Optional[Dict[str, Any]]) -> List[PlannedTask]:
|
||||
user = json.dumps({
|
||||
"schema": schema,
|
||||
"tables": tables[:300],
|
||||
"user_intent": user_intent,
|
||||
"instruction": "Plan next tasks."
|
||||
}, ensure_ascii=False)
|
||||
|
||||
raw = await self.llm.chat_json(PLANNER_SYSTEM, user, max_tokens=900)
|
||||
try:
|
||||
out = PlannerOut.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"Planner output invalid: {e}\nraw={raw}")
|
||||
|
||||
tasks = [t for t in out.tasks if t.expert in ("structural","statistical","semantic","query")]
|
||||
tasks.sort(key=lambda t: t.priority, reverse=True)
|
||||
return tasks[:6]
|
||||
|
||||
async def expert_act(self, expert: ExpertName, ctx: Dict[str, Any]) -> ExpertAct:
|
||||
system = EXPERT_ACT_SYSTEM.format(expert=expert, allowed_tools=sorted(ALLOWED_TOOLS[expert]))
|
||||
raw = await self.llm.chat_json(system, json.dumps(ctx, ensure_ascii=False), max_tokens=900)
|
||||
try:
|
||||
act = ExpertAct.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"{expert} ACT invalid: {e}\nraw={raw}")
|
||||
|
||||
act.tool_calls = act.tool_calls[:6]
|
||||
for c in act.tool_calls:
|
||||
if c.name not in KNOWN_MCP_TOOLS:
|
||||
raise MCPError(f"{expert} requested unknown tool: {c.name}")
|
||||
if c.name not in ALLOWED_TOOLS[expert]:
|
||||
raise MCPError(f"{expert} requested disallowed tool: {c.name}")
|
||||
return act
|
||||
|
||||
async def expert_reflect(self, expert: ExpertName, ctx: Dict[str, Any], tool_results: List[Dict[str, Any]]) -> ExpertReflect:
|
||||
system = EXPERT_REFLECT_SYSTEM.format(expert=expert)
|
||||
user = dict(ctx)
|
||||
user["tool_results"] = tool_results
|
||||
raw = await self.llm.chat_json(system, json.dumps(user, ensure_ascii=False), max_tokens=1200)
|
||||
try:
|
||||
ref = ExpertReflect.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise LLMError(f"{expert} REFLECT invalid: {e}\nraw={raw}")
|
||||
return ref
|
||||
|
||||
async def apply_catalog_writes(self, writes: List[CatalogWrite]):
|
||||
for w in writes:
|
||||
await self.mcp.call_tool("catalog_upsert", {
|
||||
"kind": w.kind,
|
||||
"key": w.key,
|
||||
"document": w.document,
|
||||
"tags": w.tags,
|
||||
"links": w.links
|
||||
})
|
||||
|
||||
async def run(self, ui: UIState, schema: Optional[str], max_iterations: int, tasks_per_iter: int):
|
||||
ui.phase = "bootstrap"
|
||||
|
||||
schemas_res = await self.mcp.call_tool("list_schemas", {"page_size": 50})
|
||||
schemas = schemas_res if isinstance(schemas_res, list) else normalize_list(schemas_res, ("schemas","items","result"))
|
||||
if not schemas:
|
||||
raise MCPError("No schemas returned by MCP list_schemas")
|
||||
|
||||
chosen_schema = schema or item_name(schemas[0])
|
||||
ui.last_event = f"Selected schema: {chosen_schema}"
|
||||
|
||||
tables_res = await self.mcp.call_tool("list_tables", {"schema": chosen_schema, "page_size": 500})
|
||||
tables = tables_res if isinstance(tables_res, list) else normalize_list(tables_res, ("tables","items","result"))
|
||||
table_names = [item_name(t) for t in tables]
|
||||
if not table_names:
|
||||
raise MCPError(f"No tables returned by MCP list_tables(schema={chosen_schema})")
|
||||
|
||||
user_intent = None
|
||||
try:
|
||||
ig = await self.mcp.call_tool("catalog_get", {"kind": "intent", "key": f"intent/{ui.run_id}"})
|
||||
if isinstance(ig, dict) and ig.get("document"):
|
||||
user_intent = json.loads(ig["document"])
|
||||
except Exception:
|
||||
user_intent = None
|
||||
|
||||
ui.phase = "running"
|
||||
no_progress_streak = 0
|
||||
|
||||
for it in range(1, max_iterations + 1):
|
||||
ui.iteration = it
|
||||
ui.last_event = "Planning tasks…"
|
||||
tasks = await self.planner(chosen_schema, table_names, user_intent)
|
||||
ui.planned_tasks = tasks
|
||||
ui.last_event = f"Planned {len(tasks)} tasks"
|
||||
|
||||
if not tasks:
|
||||
ui.phase = "done"
|
||||
ui.last_event = "No tasks from planner"
|
||||
return
|
||||
|
||||
executed = 0
|
||||
before_insights = ui.insights
|
||||
before_writes = ui.catalog_writes
|
||||
|
||||
for task in tasks:
|
||||
if executed >= tasks_per_iter:
|
||||
break
|
||||
executed += 1
|
||||
|
||||
expert = task.expert
|
||||
ctx = {
|
||||
"run_id": ui.run_id,
|
||||
"schema": task.schema,
|
||||
"table": task.table,
|
||||
"goal": task.goal,
|
||||
"user_intent": user_intent,
|
||||
"note": "Choose minimal tool calls to advance discovery."
|
||||
}
|
||||
|
||||
ui.last_event = f"{expert} ACT: {task.goal}" + (f" ({task.table})" if task.table else "")
|
||||
act = await self.expert_act(expert, ctx)
|
||||
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
for call in act.tool_calls:
|
||||
ui.last_event = f"MCP tool: {call.name}"
|
||||
ui.tool_calls += 1
|
||||
res = await self.mcp.call_tool(call.name, call.args)
|
||||
tool_results.append({"tool": call.name, "args": call.args, "result": res})
|
||||
|
||||
ui.last_event = f"{expert} REFLECT"
|
||||
ref = await self.expert_reflect(expert, ctx, tool_results)
|
||||
|
||||
if ref.catalog_writes:
|
||||
await self.apply_catalog_writes(ref.catalog_writes)
|
||||
ui.catalog_writes += len(ref.catalog_writes)
|
||||
|
||||
for q in ref.questions_for_user[:1]:
|
||||
payload = {
|
||||
"run_id": ui.run_id,
|
||||
"question_id": q.question_id,
|
||||
"title": q.title,
|
||||
"prompt": q.prompt,
|
||||
"options": q.options,
|
||||
"created_at": now_iso()
|
||||
}
|
||||
await self.mcp.call_tool("catalog_upsert", {
|
||||
"kind": "question",
|
||||
"key": f"question/{ui.run_id}/{q.question_id}",
|
||||
"document": json.dumps(payload, ensure_ascii=False),
|
||||
"tags": f"run:{ui.run_id}"
|
||||
})
|
||||
ui.catalog_writes += 1
|
||||
|
||||
ui.insights += len(ref.insights)
|
||||
|
||||
gained_insights = ui.insights - before_insights
|
||||
gained_writes = ui.catalog_writes - before_writes
|
||||
if gained_insights == 0 and gained_writes == 0:
|
||||
no_progress_streak += 1
|
||||
else:
|
||||
no_progress_streak = 0
|
||||
|
||||
if no_progress_streak >= 2:
|
||||
ui.phase = "done"
|
||||
ui.last_event = "Stopping: diminishing returns"
|
||||
return
|
||||
|
||||
ui.phase = "done"
|
||||
ui.last_event = "Finished: max_iterations reached"
|
||||
|
||||
|
||||
def render(ui: UIState) -> Layout:
|
||||
layout = Layout()
|
||||
|
||||
header = Text()
|
||||
header.append("Database Discovery Agent ", style="bold")
|
||||
header.append(f"(run_id: {ui.run_id})", style="dim")
|
||||
|
||||
status = Table.grid(expand=True)
|
||||
status.add_column(justify="left")
|
||||
status.add_column(justify="right")
|
||||
status.add_row("Phase", f"[bold]{ui.phase}[/bold]")
|
||||
status.add_row("Iteration", str(ui.iteration))
|
||||
status.add_row("Tool calls", str(ui.tool_calls))
|
||||
status.add_row("Catalog writes", str(ui.catalog_writes))
|
||||
status.add_row("Insights", str(ui.insights))
|
||||
|
||||
tasks_table = Table(title="Planned Tasks", expand=True)
|
||||
tasks_table.add_column("Prio", justify="right", width=6)
|
||||
tasks_table.add_column("Expert", width=11)
|
||||
tasks_table.add_column("Goal")
|
||||
tasks_table.add_column("Table", style="dim")
|
||||
|
||||
for t in (ui.planned_tasks or [])[:10]:
|
||||
tasks_table.add_row(f"{t.priority:.2f}", t.expert, t.goal, t.table or "")
|
||||
|
||||
events = Text()
|
||||
if ui.last_event:
|
||||
events.append(ui.last_event, style="white")
|
||||
if ui.last_error:
|
||||
events.append("\n")
|
||||
events.append(ui.last_error, style="bold red")
|
||||
|
||||
layout.split_column(
|
||||
Layout(Panel(header, border_style="cyan"), size=3),
|
||||
Layout(Panel(status, title="Status", border_style="green"), size=8),
|
||||
Layout(Panel(tasks_table, border_style="magenta"), ratio=2),
|
||||
Layout(Panel(events, title="Last event", border_style="yellow"), size=6),
|
||||
)
|
||||
return layout
|
||||
|
||||
|
||||
async def cmd_run(args: argparse.Namespace):
|
||||
console = Console()
|
||||
trace = TraceLogger(args.trace)
|
||||
|
||||
mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "")
|
||||
mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN")
|
||||
mcp_insecure = args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
|
||||
|
||||
llm_base = args.llm_base_url or os.getenv("LLM_BASE_URL", "https://api.openai.com")
|
||||
llm_key = args.llm_api_key or os.getenv("LLM_API_KEY", "")
|
||||
llm_model = args.llm_model or os.getenv("LLM_MODEL", "gpt-4o-mini")
|
||||
llm_insecure = args.llm_insecure_tls or (os.getenv("LLM_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
|
||||
|
||||
if not mcp_endpoint:
|
||||
console.print("[bold red]MCP endpoint is required (set MCP_ENDPOINT or --mcp-endpoint)[/bold red]")
|
||||
raise SystemExit(2)
|
||||
|
||||
if "openai.com" in llm_base and not llm_key:
|
||||
console.print("[bold red]LLM_API_KEY is required for OpenAI[/bold red]")
|
||||
raise SystemExit(2)
|
||||
|
||||
run_id = args.run_id or str(uuid.uuid4())
|
||||
ui = UIState(run_id=run_id)
|
||||
|
||||
mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure)
|
||||
llm = LLMClient(llm_base, llm_key, llm_model, trace, insecure_tls=llm_insecure)
|
||||
agent = Agent(mcp, llm, trace, debug=args.debug)
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
await agent.run(ui, args.schema, args.max_iterations, args.tasks_per_iter)
|
||||
except Exception as e:
|
||||
ui.phase = "error"
|
||||
ui.last_error = f"{type(e).__name__}: {e}"
|
||||
trace.write({"type": "error", "error": ui.last_error})
|
||||
if args.debug:
|
||||
tb = traceback.format_exc()
|
||||
trace.write({"type": "error.traceback", "traceback": tb})
|
||||
ui.last_error += "\n" + tb
|
||||
finally:
|
||||
await mcp.close()
|
||||
await llm.close()
|
||||
|
||||
task = asyncio.create_task(runner())
|
||||
    with Live(render(ui), refresh_per_second=8, console=console) as live:
        while not task.done():
            live.update(render(ui))  # refresh the TUI with the latest UIState
            await asyncio.sleep(0.1)
|
||||
|
||||
console.print(render(ui))
|
||||
if ui.phase == "error":
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
async def cmd_intent(args: argparse.Namespace):
|
||||
console = Console()
|
||||
trace = TraceLogger(args.trace)
|
||||
|
||||
mcp_endpoint = args.mcp_endpoint or os.getenv("MCP_ENDPOINT", "")
|
||||
mcp_token = args.mcp_auth_token or os.getenv("MCP_AUTH_TOKEN")
|
||||
mcp_insecure = args.mcp_insecure_tls or (os.getenv("MCP_INSECURE_TLS", "0") in ("1","true","TRUE","yes","YES"))
|
||||
|
||||
if not mcp_endpoint:
|
||||
console.print("[bold red]MCP endpoint is required[/bold red]")
|
||||
raise SystemExit(2)
|
||||
|
||||
payload = {
|
||||
"run_id": args.run_id,
|
||||
"audience": args.audience,
|
||||
"goals": args.goals,
|
||||
"constraints": {},
|
||||
"updated_at": now_iso()
|
||||
}
|
||||
for kv in (args.constraint or []):
|
||||
if "=" in kv:
|
||||
k, v = kv.split("=", 1)
|
||||
payload["constraints"][k] = v
|
||||
|
||||
mcp = MCPClient(mcp_endpoint, mcp_token, trace, insecure_tls=mcp_insecure)
|
||||
try:
|
||||
await mcp.call_tool("catalog_upsert", {
|
||||
"kind": "intent",
|
||||
"key": f"intent/{args.run_id}",
|
||||
"document": json.dumps(payload, ensure_ascii=False),
|
||||
"tags": f"run:{args.run_id}"
|
||||
})
|
||||
console.print("[green]Intent stored[/green]")
|
||||
finally:
|
||||
await mcp.close()
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(prog="discover_cli", description="Database Discovery Agent (Async CLI)")
|
||||
sub = p.add_subparsers(dest="cmd", required=True)
|
||||
|
||||
common = argparse.ArgumentParser(add_help=False)
|
||||
common.add_argument("--mcp-endpoint", default=None, help="MCP JSON-RPC endpoint (or MCP_ENDPOINT env)")
|
||||
common.add_argument("--mcp-auth-token", default=None, help="MCP auth token (or MCP_AUTH_TOKEN env)")
|
||||
common.add_argument("--mcp-insecure-tls", action="store_true", help="Disable MCP TLS verification (like curl -k)")
|
||||
common.add_argument("--llm-base-url", default=None, help="OpenAI-compatible base URL (or LLM_BASE_URL env)")
|
||||
common.add_argument("--llm-api-key", default=None, help="LLM API key (or LLM_API_KEY env)")
|
||||
common.add_argument("--llm-model", default=None, help="LLM model (or LLM_MODEL env)")
|
||||
common.add_argument("--llm-insecure-tls", action="store_true", help="Disable LLM TLS verification")
|
||||
common.add_argument("--trace", default=None, help="Write JSONL trace to this file")
|
||||
common.add_argument("--debug", action="store_true", help="Show stack traces")
|
||||
|
||||
prun = sub.add_parser("run", parents=[common], help="Run discovery")
|
||||
prun.add_argument("--run-id", default=None, help="Optional run id (uuid). If omitted, generated.")
|
||||
prun.add_argument("--schema", default=None, help="Optional schema to focus on")
|
||||
prun.add_argument("--max-iterations", type=int, default=6)
|
||||
prun.add_argument("--tasks-per-iter", type=int, default=3)
|
||||
prun.set_defaults(func=cmd_run)
|
||||
|
||||
pint = sub.add_parser("intent", parents=[common], help="Set user intent for a run (stored in MCP catalog)")
|
||||
pint.add_argument("--run-id", required=True)
|
||||
pint.add_argument("--audience", default="mixed")
|
||||
pint.add_argument("--goals", nargs="*", default=["qna"])
|
||||
pint.add_argument("--constraint", action="append", help="constraint as key=value; repeatable")
|
||||
pint.set_defaults(func=cmd_intent)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def main():
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
asyncio.run(args.func(args))
|
||||
except KeyboardInterrupt:
|
||||
Console().print("\n[yellow]Interrupted[/yellow]")
|
||||
raise SystemExit(130)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,4 @@
httpx==0.27.0
pydantic==2.8.2
python-dotenv==1.0.1
rich==13.7.1