agent_graph.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. from __future__ import annotations
  2. import json
  3. import time
  4. from datetime import datetime, timezone
  5. from typing import Any, Literal
  6. from langgraph.graph import END, START, StateGraph
  7. from .config import RuntimeConfig
  8. from .mcp_client import McpClient
  9. from .models import (
  10. AgentResultTable,
  11. AgentRunResponse,
  12. AgentState,
  13. AgentTraceEvent,
  14. ChatMessage,
  15. McpToolDescriptor,
  16. PlannedToolCall,
  17. ToolObservation,
  18. )
# Names of the graph nodes that ``SmqjhAgentGraph._route_after_plan`` can
# return to choose the branch taken after the ``plan`` node.
Route = Literal["execute_tools", "summarize", "chat"]
  20. class SmqjhAgentGraph:
  21. def __init__(self, config: RuntimeConfig) -> None:
  22. self.config = config
  23. self.mcp = McpClient(config.mcp.url, config.mcp.token)
  24. self.graph = self._build_graph()
  25. def health(self) -> Any:
  26. return self.mcp.status()
  27. def tools(self) -> list[McpToolDescriptor]:
  28. return self._load_runnable_tools()
  29. def run(self, message: str, history: list[ChatMessage] | None = None) -> AgentRunResponse:
  30. state: AgentState = {
  31. "message": message,
  32. "history": history or [],
  33. "visitedSignatures": [],
  34. "steps": 0,
  35. "trace": [_trace("system", "收到用户请求", message)],
  36. }
  37. result = self.graph.invoke(state)
  38. final = result.get("final")
  39. if final:
  40. return final
  41. observations = result.get("observations", [])
  42. trace_events = result.get("trace", [])
  43. return {
  44. "content": "这次没有生成有效回答,请稍后重试。",
  45. "model": "none",
  46. "usedMcp": False,
  47. "steps": int(result.get("steps", 0)),
  48. "toolCalls": observations,
  49. "tables": _extract_tables(observations),
  50. "trace": trace_events,
  51. }
  52. def _build_graph(self):
  53. workflow = StateGraph(AgentState)
  54. workflow.add_node("load_tools", self._load_tools_node)
  55. workflow.add_node("plan", self._plan_node)
  56. workflow.add_node("execute_tools", self._execute_tools_node)
  57. workflow.add_node("summarize", self._summarize_node)
  58. workflow.add_node("chat", self._chat_node)
  59. workflow.add_edge(START, "load_tools")
  60. workflow.add_edge("load_tools", "plan")
  61. workflow.add_conditional_edges(
  62. "plan",
  63. self._route_after_plan,
  64. {
  65. "execute_tools": "execute_tools",
  66. "summarize": "summarize",
  67. "chat": "chat",
  68. },
  69. )
  70. workflow.add_edge("execute_tools", "plan")
  71. workflow.add_edge("summarize", END)
  72. workflow.add_edge("chat", END)
  73. return workflow.compile()
  74. def _load_tools_node(self, state: AgentState) -> AgentState:
  75. try:
  76. tools = self._load_runnable_tools()
  77. return {
  78. "tools": tools,
  79. "trace": _append_trace(state, _trace("system", "MCP 工具加载完成", f"可用工具 {len(tools)} 个")),
  80. }
  81. except Exception as error:
  82. return {
  83. "tools": [],
  84. "trace": _append_trace(state, _trace("error", "MCP 工具加载失败", _error_message(error))),
  85. }
  86. def _plan_node(self, state: AgentState) -> AgentState:
  87. if int(state.get("steps", 0)) >= self.config.max_steps:
  88. return {
  89. "plannedToolCalls": [],
  90. "trace": _append_trace(state, _trace("plan", "达到最大工具步数", f"maxSteps={self.config.max_steps}")),
  91. }
  92. try:
  93. result = self.mcp.call_tool(
  94. "smqjh.ai.tool.plan",
  95. {
  96. "request": self._assistant_request(state),
  97. "tools": state.get("tools", []),
  98. "observations": state.get("observations", []),
  99. },
  100. )
  101. record = _as_dict(result.get("structuredContent"))
  102. raw_calls = record.get("toolCalls")
  103. raw_call_items = raw_calls if isinstance(raw_calls, list) else []
  104. visited = set(state.get("visitedSignatures", []))
  105. planned = [
  106. call
  107. for call in (_normalize_tool_call(item, state.get("message", "")) for item in raw_call_items)
  108. if call and _tool_signature(call["name"], call["arguments"]) not in visited
  109. ][:3]
  110. return {
  111. "plannedToolCalls": planned,
  112. "trace": _append_trace(
  113. state,
  114. _trace(
  115. "plan",
  116. "DeepSeek 已规划工具调用" if planned else "DeepSeek 判断无需继续调用工具",
  117. _reason_text(record),
  118. {"toolCalls": planned},
  119. ),
  120. ),
  121. }
  122. except Exception as error:
  123. return {
  124. "plannedToolCalls": [],
  125. "trace": _append_trace(state, _trace("error", "DeepSeek 工具规划失败", _error_message(error))),
  126. }
  127. def _execute_tools_node(self, state: AgentState) -> AgentState:
  128. by_name = {tool.get("name", ""): tool for tool in state.get("tools", [])}
  129. observations: list[ToolObservation] = []
  130. visited = set(state.get("visitedSignatures", []))
  131. for call in state.get("plannedToolCalls", []):
  132. requested_name = call["name"]
  133. arguments = call["arguments"]
  134. resolved_name = _resolve_tool_name(requested_name, by_name)
  135. signature = _tool_signature(resolved_name or requested_name, arguments)
  136. visited.add(signature)
  137. if not resolved_name:
  138. observations.append(
  139. {
  140. "name": requested_name,
  141. "source": "mcp",
  142. "arguments": arguments,
  143. "ok": False,
  144. "error": "DeepSeek 选择了未注册 MCP 工具",
  145. }
  146. )
  147. continue
  148. started = time.perf_counter()
  149. try:
  150. result = self.mcp.call_tool(resolved_name, arguments)
  151. observations.append(
  152. {
  153. "name": resolved_name,
  154. "source": "mcp",
  155. "arguments": arguments,
  156. "ok": True,
  157. "result": result.get("structuredContent", result),
  158. "durationMs": int((time.perf_counter() - started) * 1000),
  159. }
  160. )
  161. except Exception as error:
  162. observations.append(
  163. {
  164. "name": resolved_name,
  165. "source": "mcp",
  166. "arguments": arguments,
  167. "ok": False,
  168. "error": _error_message(error),
  169. "durationMs": int((time.perf_counter() - started) * 1000),
  170. }
  171. )
  172. trace_events = state.get("trace", [])
  173. for item in observations:
  174. trace_events = [
  175. *trace_events,
  176. _trace(
  177. "tool" if item.get("ok") else "error",
  178. f"工具执行完成:{item['name']}" if item.get("ok") else f"工具执行失败:{item['name']}",
  179. f"{item.get('durationMs', 0)}ms" if item.get("ok") else item.get("error", ""),
  180. {"arguments": item.get("arguments"), "result": _summarize_result(item.get("result"))},
  181. ),
  182. ]
  183. return {
  184. "observations": [*state.get("observations", []), *observations],
  185. "plannedToolCalls": [],
  186. "visitedSignatures": list(visited),
  187. "steps": int(state.get("steps", 0)) + 1,
  188. "trace": trace_events,
  189. }
  190. def _summarize_node(self, state: AgentState) -> AgentState:
  191. observations = state.get("observations", [])
  192. tables = _extract_tables(observations)
  193. try:
  194. result = self.mcp.call_tool(
  195. "smqjh.ai.tool.summarize",
  196. {
  197. "request": self._assistant_request(state),
  198. "observations": observations,
  199. },
  200. )
  201. record = _as_dict(result.get("structuredContent"))
  202. content = str(record.get("content") or "").strip() or _fallback_summary(observations, tables)
  203. model = str(record.get("model") or "mcp-deepseek")
  204. trace_events = _append_trace(state, _trace("summary", "结果总结完成", f"{len(tables)} 个表格"))
  205. return {
  206. "final": {
  207. "content": content,
  208. "model": model,
  209. "usedMcp": True,
  210. "steps": int(state.get("steps", 0)),
  211. "toolCalls": observations,
  212. "tables": tables,
  213. "trace": trace_events,
  214. },
  215. "trace": trace_events,
  216. }
  217. except Exception as error:
  218. trace_events = _append_trace(state, _trace("error", "结果总结失败,已使用工具结果兜底", _error_message(error)))
  219. return {
  220. "final": {
  221. "content": _fallback_summary(observations, tables),
  222. "model": "tool-only",
  223. "usedMcp": True,
  224. "steps": int(state.get("steps", 0)),
  225. "toolCalls": observations,
  226. "tables": tables,
  227. "trace": trace_events,
  228. },
  229. "trace": trace_events,
  230. }
  231. def _chat_node(self, state: AgentState) -> AgentState:
  232. try:
  233. result = self.mcp.call_tool("smqjh.ai.chat", {"request": self._assistant_request(state)})
  234. record = _as_dict(result.get("structuredContent"))
  235. content = str(record.get("content") or "").strip() or "DeepSeek 没有返回有效内容。"
  236. model = str(record.get("model") or "mcp-deepseek")
  237. trace_events = _append_trace(state, _trace("chat", "普通对话完成", model))
  238. return {
  239. "final": {
  240. "content": content,
  241. "model": model,
  242. "usedMcp": True,
  243. "steps": int(state.get("steps", 0)),
  244. "toolCalls": [],
  245. "tables": [],
  246. "trace": trace_events,
  247. },
  248. "trace": trace_events,
  249. }
  250. except Exception as error:
  251. trace_events = _append_trace(state, _trace("error", "普通对话失败", _error_message(error)))
  252. return {
  253. "final": {
  254. "content": f"暂时无法完成回答:{_error_message(error)}",
  255. "model": "none",
  256. "usedMcp": False,
  257. "steps": int(state.get("steps", 0)),
  258. "toolCalls": [],
  259. "tables": [],
  260. "trace": trace_events,
  261. },
  262. "trace": trace_events,
  263. }
  264. def _route_after_plan(self, state: AgentState) -> Route:
  265. if state.get("plannedToolCalls") and int(state.get("steps", 0)) < self.config.max_steps:
  266. return "execute_tools"
  267. if state.get("observations"):
  268. return "summarize"
  269. return "chat"
  270. def _assistant_request(self, state: AgentState) -> dict[str, Any]:
  271. return {
  272. "message": state.get("message", ""),
  273. "history": state.get("history", []),
  274. "environmentName": self.config.environment_name,
  275. "baseUrl": self.config.base_url,
  276. "authenticated": True,
  277. "username": "web-agent",
  278. }
  279. def _load_runnable_tools(self) -> list[McpToolDescriptor]:
  280. return [tool for tool in self.mcp.list_tools() if not str(tool.get("name", "")).startswith("smqjh.ai.")]
  281. def _normalize_tool_call(value: Any, message: str) -> PlannedToolCall | None:
  282. record = _as_dict(value)
  283. name = str(record.get("name") or "").strip()
  284. if not name:
  285. return None
  286. return {
  287. "name": name,
  288. "arguments": _repair_tool_arguments(name, _as_dict(record.get("arguments")), message),
  289. }
  290. def _repair_tool_arguments(name: str, arguments: dict[str, Any], message: str) -> dict[str, Any]:
  291. if name == "smqjh.database.smart.query" and not isinstance(arguments.get("question"), str):
  292. return {**arguments, "question": message}
  293. if name == "smqjh.product.lookup.summary":
  294. current = str(arguments.get("productKeyword") or "").strip()
  295. return {**arguments, "productKeyword": current or _extract_product_keyword(message)}
  296. return arguments
  297. def _extract_product_keyword(message: str) -> str:
  298. text = message
  299. for token in [
  300. "帮我",
  301. "麻烦",
  302. "查询一下",
  303. "查一下",
  304. "查询",
  305. "查看",
  306. "当前",
  307. "业务系统",
  308. "系统里面",
  309. "系统里",
  310. "后台",
  311. "我方",
  312. "我们的",
  313. "商品库",
  314. "商品表",
  315. "商品描述是什么",
  316. "商品描述",
  317. "描述是什么",
  318. "描述",
  319. "价格是多少",
  320. "价格",
  321. "定价",
  322. "是多少",
  323. "是什么",
  324. "呢",
  325. ]:
  326. text = text.replace(token, " ")
  327. return " ".join(text.replace(",", " ").replace("。", " ").replace("?", " ").replace("?", " ").split())
  328. def _resolve_tool_name(name: str, by_name: dict[str, McpToolDescriptor]) -> str | None:
  329. candidates = [
  330. name,
  331. name.removeprefix("smqjh.") if name.startswith("smqjh.") else f"smqjh.{name}",
  332. "smqjh.product.lookup.summary" if name == "product.lookup.summary" else "",
  333. "smqjh.order.count.query" if name == "order.count.query" else "",
  334. "smqjh.database.smart.query" if name == "database.smart.query" else "",
  335. "smqjh.database.readonly.query" if name == "database.readonly.query" else "",
  336. "smqjh.cloud.health" if name == "cloud.health" else "",
  337. ]
  338. return next((candidate for candidate in candidates if candidate and candidate in by_name), None)
  339. def _extract_tables(observations: list[ToolObservation]) -> list[AgentResultTable]:
  340. tables: list[AgentResultTable] = []
  341. for observation in observations:
  342. if not observation.get("ok"):
  343. continue
  344. tables.extend(_extract_tables_from_value(_tool_title(observation.get("name", "")), observation.get("result")))
  345. return tables[:6]
  346. def _extract_tables_from_value(title: str, value: Any) -> list[AgentResultTable]:
  347. tables: list[AgentResultTable] = []
  348. seen: set[int] = set()
  349. def visit(node: Any, current_title: str, depth: int) -> None:
  350. if depth > 5 or not isinstance(node, dict):
  351. return
  352. node_id = id(node)
  353. if node_id in seen:
  354. return
  355. seen.add(node_id)
  356. add_rows_table(current_title, node, "rows")
  357. add_rows_table(current_title, node, "comparisonRows")
  358. for key in ("data", "structuredContent", "result", "record"):
  359. if key in node:
  360. visit(node[key], current_title, depth + 1)
  361. def add_rows_table(current_title: str, record: dict[str, Any], rows_key: str) -> None:
  362. rows = record.get(rows_key)
  363. if not isinstance(rows, list) or not all(isinstance(row, dict) for row in rows):
  364. return
  365. row_records = [_normalize_row(row) for row in rows[:200]]
  366. provided_columns = record.get("columns")
  367. if isinstance(provided_columns, list) and provided_columns:
  368. columns = [str(column) for column in provided_columns]
  369. else:
  370. columns = []
  371. for row in row_records:
  372. for key in row:
  373. if key not in columns:
  374. columns.append(key)
  375. if not columns:
  376. return
  377. tables.append(
  378. {
  379. "title": _build_table_title(current_title, record, rows_key),
  380. "columns": columns,
  381. "rows": row_records,
  382. }
  383. )
  384. visit(value, title, 0)
  385. return tables
  386. def _fallback_summary(observations: list[ToolObservation], tables: list[AgentResultTable]) -> str:
  387. ok_count = len([item for item in observations if item.get("ok")])
  388. failed = [item for item in observations if not item.get("ok")]
  389. lines = [f"已执行 {len(observations)} 个 MCP 工具,成功 {ok_count} 个。"]
  390. if tables:
  391. lines.append("结果已整理为下方表格。")
  392. if failed:
  393. lines.append("失败工具:" + ";".join(f"{item.get('name')}:{item.get('error') or '调用失败'}" for item in failed))
  394. evidence = _find_evidence(observations)
  395. if evidence:
  396. lines.append(f"依据:{evidence}")
  397. return "\n\n".join(lines)
  398. def _find_evidence(observations: list[ToolObservation]) -> str:
  399. for observation in observations:
  400. sql = _find_key_value(observation.get("result"), "executedSql", 0)
  401. if isinstance(sql, str) and sql.strip():
  402. return sql.strip()
  403. evidence = _find_key_value(observation.get("result"), "evidence", 0)
  404. if isinstance(evidence, str) and evidence.strip():
  405. return evidence.strip()
  406. return ""
  407. def _find_key_value(value: Any, key: str, depth: int) -> Any:
  408. if depth > 4 or not isinstance(value, dict):
  409. return None
  410. if key in value:
  411. return value[key]
  412. for child in value.values():
  413. result = _find_key_value(child, key, depth + 1)
  414. if result is not None:
  415. return result
  416. return None
  417. def _trace(phase: str, title: str, detail: str = "", data: Any = None) -> AgentTraceEvent:
  418. event: AgentTraceEvent = {
  419. "at": datetime.now(timezone.utc).isoformat(),
  420. "phase": phase, # type: ignore[typeddict-item]
  421. "title": title,
  422. }
  423. if detail:
  424. event["detail"] = detail
  425. if data is not None:
  426. event["data"] = data
  427. return event
  428. def _append_trace(state: AgentState, event: AgentTraceEvent) -> list[AgentTraceEvent]:
  429. return [*state.get("trace", []), event]
  430. def _reason_text(record: dict[str, Any]) -> str:
  431. reason = record.get("reason")
  432. return reason.strip() if isinstance(reason, str) else ""
  433. def _summarize_result(value: Any) -> Any:
  434. record = _as_dict(value)
  435. if isinstance(record.get("rowCount"), int):
  436. return {"rowCount": record.get("rowCount"), "title": record.get("title"), "summary": record.get("summary")}
  437. if isinstance(record.get("rows"), list):
  438. return {"rows": len(record["rows"])}
  439. return value
  440. def _build_table_title(base_title: str, record: dict[str, Any], rows_key: str) -> str:
  441. title = str(record.get("title") or base_title).strip()
  442. count = f"({record['rowCount']} 行)" if "rowCount" in record else ""
  443. return f"{title}:价格对比" if rows_key == "comparisonRows" else f"{title}{count}"
  444. def _tool_title(name: str) -> str:
  445. titles = {
  446. "smqjh.config.get": "运行配置",
  447. "smqjh.cloud.health": "网关连通检查",
  448. "smqjh.schema.search": "业务表搜索结果",
  449. "smqjh.schema.getTable": "业务表说明",
  450. "smqjh.schema.businessRules": "业务规则",
  451. "smqjh.database.readonly.query": "数据库只读查询结果",
  452. "smqjh.database.smart.query": "智能数据库查询结果",
  453. "smqjh.order.count.query": "订单统计结果",
  454. "smqjh.product.lookup.summary": "商品资料查询结果",
  455. "smqjh.settlement.enterprise.list": "月结企业清单",
  456. "smqjh.settlement.monthly.plan": "企业月结计划",
  457. }
  458. return titles.get(name, name)
  459. def _tool_signature(name: str, arguments: dict[str, Any]) -> str:
  460. return f"{name}:{json.dumps(arguments, ensure_ascii=False, sort_keys=True, default=str)}"
  461. def _as_dict(value: Any) -> dict[str, Any]:
  462. return value if isinstance(value, dict) else {}
  463. def _normalize_row(row: dict[str, Any]) -> dict[str, str]:
  464. return {str(key): "" if value is None else str(value) for key, value in row.items()}
  465. def _error_message(error: Exception) -> str:
  466. return str(error) or error.__class__.__name__