# agent_graph.py
from __future__ import annotations

import json
import logging
import re
import time
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any, Callable, Literal

from langgraph.graph import END, START, StateGraph

from .config import RuntimeConfig
from .mcp_client import McpClient
from .models import (
    AgentResultTable,
    AgentRunResponse,
    AgentState,
    AgentTraceEvent,
    ChatMessage,
    McpToolDescriptor,
    PlannedToolCall,
    ToolObservation,
)
  21. Route = Literal["execute_tools", "summarize", "chat"]
  22. TraceSink = Callable[[AgentTraceEvent], None]
  23. _TRACE_SINK: ContextVar[TraceSink | None] = ContextVar("smqjh_agent_trace_sink", default=None)
  24. class SmqjhAgentGraph:
  25. def __init__(self, config: RuntimeConfig) -> None:
  26. self.config = config
  27. self.mcp = McpClient(config.mcp.url, config.mcp.token)
  28. self.graph = self._build_graph()
  29. def health(self) -> Any:
  30. return self.mcp.status()
  31. def tools(self) -> list[McpToolDescriptor]:
  32. return self._load_runnable_tools()
  33. def run(
  34. self,
  35. message: str,
  36. history: list[ChatMessage] | None = None,
  37. emit: TraceSink | None = None,
  38. ) -> AgentRunResponse:
  39. token = _TRACE_SINK.set(emit)
  40. try:
  41. state: AgentState = {
  42. "message": message,
  43. "history": history or [],
  44. "visitedSignatures": [],
  45. "steps": 0,
  46. "trace": [_trace("system", "已收到请求", message)],
  47. }
  48. result = self.graph.invoke(state)
  49. final = result.get("final")
  50. if final:
  51. return final
  52. observations = result.get("observations", [])
  53. trace_events = result.get("trace", [])
  54. return {
  55. "content": "这次没有生成有效回答,请稍后重试。",
  56. "model": "none",
  57. "usedMcp": False,
  58. "steps": int(result.get("steps", 0)),
  59. "toolCalls": observations,
  60. "tables": _extract_tables(observations),
  61. "trace": trace_events,
  62. }
  63. finally:
  64. _TRACE_SINK.reset(token)
  65. def _build_graph(self):
  66. workflow = StateGraph(AgentState)
  67. workflow.add_node("load_tools", self._load_tools_node)
  68. workflow.add_node("plan", self._plan_node)
  69. workflow.add_node("execute_tools", self._execute_tools_node)
  70. workflow.add_node("summarize", self._summarize_node)
  71. workflow.add_node("chat", self._chat_node)
  72. workflow.add_edge(START, "load_tools")
  73. workflow.add_edge("load_tools", "plan")
  74. workflow.add_conditional_edges(
  75. "plan",
  76. self._route_after_plan,
  77. {
  78. "execute_tools": "execute_tools",
  79. "summarize": "summarize",
  80. "chat": "chat",
  81. },
  82. )
  83. workflow.add_edge("execute_tools", "plan")
  84. workflow.add_edge("summarize", END)
  85. workflow.add_edge("chat", END)
  86. return workflow.compile()
  87. def _load_tools_node(self, state: AgentState) -> AgentState:
  88. start = _trace("system", "正在连接 MCP", "加载可用业务工具和数据库只读查询能力")
  89. try:
  90. tools = self._load_runnable_tools()
  91. done = _trace("system", "MCP 工具加载完成", f"可用工具 {len(tools)} 个")
  92. return {
  93. "tools": tools,
  94. "trace": _append_trace(state, start, done),
  95. }
  96. except Exception as error:
  97. failed = _trace("error", "MCP 工具加载失败", _error_message(error))
  98. return {
  99. "tools": [],
  100. "trace": _append_trace(state, start, failed),
  101. }
  102. def _plan_node(self, state: AgentState) -> AgentState:
  103. if int(state.get("steps", 0)) >= self.config.max_steps:
  104. return {
  105. "plannedToolCalls": [],
  106. "trace": _append_trace(
  107. state,
  108. _trace("plan", "达到最大工具步数", f"maxSteps={self.config.max_steps}"),
  109. ),
  110. }
  111. start = _trace("plan", "正在分析需求", "DeepSeek 正在判断是否需要查库或调用业务工具")
  112. try:
  113. result = self.mcp.call_tool(
  114. "smqjh.ai.tool.plan",
  115. {
  116. "request": self._assistant_request(state),
  117. "tools": state.get("tools", []),
  118. "observations": state.get("observations", []),
  119. },
  120. )
  121. record = _as_dict(result.get("structuredContent"))
  122. raw_calls = record.get("toolCalls")
  123. raw_call_items = raw_calls if isinstance(raw_calls, list) else []
  124. visited = set(state.get("visitedSignatures", []))
  125. planned = [
  126. call
  127. for call in (_normalize_tool_call(item, state.get("message", "")) for item in raw_call_items)
  128. if call and _tool_signature(call["name"], call["arguments"]) not in visited
  129. ][:3]
  130. done = _trace(
  131. "plan",
  132. "DeepSeek 已规划工具调用" if planned else "DeepSeek 判断无需继续调用工具",
  133. _reason_text(record),
  134. {"toolCalls": planned},
  135. )
  136. return {
  137. "plannedToolCalls": planned,
  138. "trace": _append_trace(state, start, done),
  139. }
  140. except Exception as error:
  141. failed = _trace("error", "DeepSeek 工具规划失败", _error_message(error))
  142. return {
  143. "plannedToolCalls": [],
  144. "trace": _append_trace(state, start, failed),
  145. }
  146. def _execute_tools_node(self, state: AgentState) -> AgentState:
  147. by_name = {tool.get("name", ""): tool for tool in state.get("tools", [])}
  148. observations: list[ToolObservation] = []
  149. visited = set(state.get("visitedSignatures", []))
  150. trace_events = list(state.get("trace", []))
  151. for call in state.get("plannedToolCalls", []):
  152. requested_name = call["name"]
  153. arguments = call["arguments"]
  154. resolved_name = _resolve_tool_name(requested_name, by_name)
  155. signature = _tool_signature(resolved_name or requested_name, arguments)
  156. visited.add(signature)
  157. if not resolved_name:
  158. observation: ToolObservation = {
  159. "name": requested_name,
  160. "source": "mcp",
  161. "arguments": arguments,
  162. "ok": False,
  163. "error": "DeepSeek 选择了未注册 MCP 工具",
  164. }
  165. observations.append(observation)
  166. trace_events.append(_trace("error", f"工具未注册:{requested_name}", observation["error"]))
  167. continue
  168. trace_events.append(_trace("tool", f"正在执行:{resolved_name}", _tool_title(resolved_name)))
  169. started = time.perf_counter()
  170. try:
  171. result = self.mcp.call_tool(resolved_name, arguments)
  172. observation = {
  173. "name": resolved_name,
  174. "source": "mcp",
  175. "arguments": arguments,
  176. "ok": True,
  177. "result": result.get("structuredContent", result),
  178. "durationMs": int((time.perf_counter() - started) * 1000),
  179. }
  180. observations.append(observation)
  181. trace_events.append(
  182. _trace(
  183. "tool",
  184. f"工具执行完成:{resolved_name}",
  185. f"{observation.get('durationMs', 0)}ms",
  186. {"arguments": arguments, "result": _summarize_result(observation.get("result"))},
  187. )
  188. )
  189. except Exception as error:
  190. observation = {
  191. "name": resolved_name,
  192. "source": "mcp",
  193. "arguments": arguments,
  194. "ok": False,
  195. "error": _error_message(error),
  196. "durationMs": int((time.perf_counter() - started) * 1000),
  197. }
  198. observations.append(observation)
  199. trace_events.append(
  200. _trace(
  201. "error",
  202. f"工具执行失败:{resolved_name}",
  203. observation.get("error", ""),
  204. {"arguments": arguments},
  205. )
  206. )
  207. return {
  208. "observations": [*state.get("observations", []), *observations],
  209. "plannedToolCalls": [],
  210. "visitedSignatures": list(visited),
  211. "steps": int(state.get("steps", 0)) + 1,
  212. "trace": trace_events,
  213. }
  214. def _summarize_node(self, state: AgentState) -> AgentState:
  215. observations = state.get("observations", [])
  216. tables = _extract_tables(observations)
  217. start = _trace("summary", "正在整理结果", "将工具结果整理成结论、依据和表格")
  218. try:
  219. result = self.mcp.call_tool(
  220. "smqjh.ai.tool.summarize",
  221. {
  222. "request": self._assistant_request(state),
  223. "observations": observations,
  224. },
  225. )
  226. record = _as_dict(result.get("structuredContent"))
  227. content = str(record.get("content") or "").strip() or _fallback_summary(observations, tables)
  228. model = str(record.get("model") or "mcp-deepseek")
  229. done = _trace("summary", "结果总结完成", f"{len(tables)} 个表格")
  230. trace_events = _append_trace(state, start, done)
  231. return {
  232. "final": {
  233. "content": content,
  234. "model": model,
  235. "usedMcp": True,
  236. "steps": int(state.get("steps", 0)),
  237. "toolCalls": observations,
  238. "tables": tables,
  239. "trace": trace_events,
  240. },
  241. "trace": trace_events,
  242. }
  243. except Exception as error:
  244. failed = _trace("error", "结果总结失败,已使用工具结果兜底", _error_message(error))
  245. trace_events = _append_trace(state, start, failed)
  246. return {
  247. "final": {
  248. "content": _fallback_summary(observations, tables),
  249. "model": "tool-only",
  250. "usedMcp": True,
  251. "steps": int(state.get("steps", 0)),
  252. "toolCalls": observations,
  253. "tables": tables,
  254. "trace": trace_events,
  255. },
  256. "trace": trace_events,
  257. }
  258. def _chat_node(self, state: AgentState) -> AgentState:
  259. start = _trace("chat", "正在生成回复", "没有继续调用业务工具,进入普通对话回答")
  260. try:
  261. result = self.mcp.call_tool("smqjh.ai.chat", {"request": self._assistant_request(state)})
  262. record = _as_dict(result.get("structuredContent"))
  263. content = str(record.get("content") or "").strip() or "DeepSeek 没有返回有效内容。"
  264. model = str(record.get("model") or "mcp-deepseek")
  265. done = _trace("chat", "普通对话完成", model)
  266. trace_events = _append_trace(state, start, done)
  267. return {
  268. "final": {
  269. "content": content,
  270. "model": model,
  271. "usedMcp": True,
  272. "steps": int(state.get("steps", 0)),
  273. "toolCalls": [],
  274. "tables": [],
  275. "trace": trace_events,
  276. },
  277. "trace": trace_events,
  278. }
  279. except Exception as error:
  280. failed = _trace("error", "普通对话失败", _error_message(error))
  281. trace_events = _append_trace(state, start, failed)
  282. return {
  283. "final": {
  284. "content": f"暂时无法完成回答:{_error_message(error)}",
  285. "model": "none",
  286. "usedMcp": False,
  287. "steps": int(state.get("steps", 0)),
  288. "toolCalls": [],
  289. "tables": [],
  290. "trace": trace_events,
  291. },
  292. "trace": trace_events,
  293. }
  294. def _route_after_plan(self, state: AgentState) -> Route:
  295. if state.get("plannedToolCalls") and int(state.get("steps", 0)) < self.config.max_steps:
  296. return "execute_tools"
  297. if state.get("observations"):
  298. return "summarize"
  299. return "chat"
  300. def _assistant_request(self, state: AgentState) -> dict[str, Any]:
  301. return {
  302. "message": state.get("message", ""),
  303. "history": state.get("history", []),
  304. "environmentName": self.config.environment_name,
  305. "baseUrl": self.config.base_url,
  306. "authenticated": True,
  307. "username": "web-agent",
  308. }
  309. def _load_runnable_tools(self) -> list[McpToolDescriptor]:
  310. return [tool for tool in self.mcp.list_tools() if not str(tool.get("name", "")).startswith("smqjh.ai.")]
  311. def _normalize_tool_call(value: Any, message: str) -> PlannedToolCall | None:
  312. record = _as_dict(value)
  313. name = str(record.get("name") or "").strip()
  314. if not name:
  315. return None
  316. return {
  317. "name": name,
  318. "arguments": _repair_tool_arguments(name, _as_dict(record.get("arguments")), message),
  319. }
  320. def _repair_tool_arguments(name: str, arguments: dict[str, Any], message: str) -> dict[str, Any]:
  321. if name == "smqjh.database.smart.query" and not isinstance(arguments.get("question"), str):
  322. return {**arguments, "question": message}
  323. if name == "smqjh.product.lookup.summary":
  324. current = str(arguments.get("productKeyword") or "").strip()
  325. return {**arguments, "productKeyword": current or _extract_product_keyword(message)}
  326. if name in {
  327. "smqjh.settlement.monthly.plan",
  328. "smqjh.settlement.monthly.preview",
  329. "settlement.monthly.plan",
  330. "settlement.monthly.preview",
  331. }:
  332. enterprise = str(arguments.get("enterprise") or "").strip() or _extract_settlement_enterprise(message)
  333. month = str(arguments.get("month") or "").strip() or _extract_month(message)
  334. return {**arguments, "enterprise": enterprise, "month": month}
  335. return arguments
  336. def _extract_product_keyword(message: str) -> str:
  337. text = message
  338. stop_words = [
  339. "帮我",
  340. "麻烦",
  341. "查询一下",
  342. "查询",
  343. "查看",
  344. "当前",
  345. "业务系统",
  346. "系统里面",
  347. "系统里",
  348. "后台",
  349. "我方",
  350. "我们的",
  351. "商品库",
  352. "商品表",
  353. "商品描述是什么",
  354. "商品描述",
  355. "描述是什么",
  356. "描述",
  357. "价格是多少",
  358. "价格",
  359. "定价",
  360. "是多少",
  361. "是什么",
  362. "呢",
  363. ]
  364. for token in stop_words:
  365. text = text.replace(token, " ")
  366. text = re.sub(r"[,。!?、;:,.!?;:]", " ", text)
  367. return " ".join(text.split())
  368. def _extract_settlement_enterprise(message: str) -> str:
  369. for enterprise in ("铜仁移动", "招商银行贵阳分行", "中数未来"):
  370. if enterprise in message:
  371. return enterprise
  372. return ""
  373. def _extract_month(message: str) -> str:
  374. explicit = re.search(r"(20\d{2})[-/年](0?[1-9]|1[0-2])", message)
  375. if explicit:
  376. return f"{explicit.group(1)}-{int(explicit.group(2)):02d}"
  377. month_only = re.search(r"(?<!\d)(0?[1-9]|1[0-2])\s*月份?", message)
  378. if month_only:
  379. return f"{datetime.now().year}-{int(month_only.group(1)):02d}"
  380. return ""
  381. def _resolve_tool_name(name: str, by_name: dict[str, McpToolDescriptor]) -> str | None:
  382. candidates = [
  383. name,
  384. name.removeprefix("smqjh.") if name.startswith("smqjh.") else f"smqjh.{name}",
  385. "smqjh.product.lookup.summary" if name == "product.lookup.summary" else "",
  386. "smqjh.order.count.query" if name == "order.count.query" else "",
  387. "smqjh.database.smart.query" if name == "database.smart.query" else "",
  388. "smqjh.database.readonly.query" if name == "database.readonly.query" else "",
  389. "smqjh.settlement.monthly.preview" if name == "settlement.monthly.preview" else "",
  390. "smqjh.settlement.monthly.plan" if name == "settlement.monthly.plan" else "",
  391. "smqjh.cloud.health" if name == "cloud.health" else "",
  392. ]
  393. return next((candidate for candidate in candidates if candidate and candidate in by_name), None)
  394. def _extract_tables(observations: list[ToolObservation]) -> list[AgentResultTable]:
  395. tables: list[AgentResultTable] = []
  396. for observation in observations:
  397. if not observation.get("ok"):
  398. continue
  399. tables.extend(_extract_tables_from_value(_tool_title(observation.get("name", "")), observation.get("result")))
  400. return tables[:6]
  401. def _extract_tables_from_value(title: str, value: Any) -> list[AgentResultTable]:
  402. tables: list[AgentResultTable] = []
  403. seen: set[int] = set()
  404. def visit(node: Any, current_title: str, depth: int) -> None:
  405. if depth > 5 or not isinstance(node, dict):
  406. return
  407. node_id = id(node)
  408. if node_id in seen:
  409. return
  410. seen.add(node_id)
  411. add_rows_table(current_title, node, "rows")
  412. add_rows_table(current_title, node, "comparisonRows")
  413. for key in ("data", "structuredContent", "result", "record"):
  414. if key in node:
  415. visit(node[key], current_title, depth + 1)
  416. def add_rows_table(current_title: str, record: dict[str, Any], rows_key: str) -> None:
  417. rows = record.get(rows_key)
  418. if not isinstance(rows, list) or not all(isinstance(row, dict) for row in rows):
  419. return
  420. row_records = [_normalize_row(row) for row in rows[:200]]
  421. column_key = "comparisonColumns" if rows_key == "comparisonRows" else "columns"
  422. provided_columns = record.get(column_key)
  423. if isinstance(provided_columns, list) and provided_columns:
  424. columns = [str(column) for column in provided_columns]
  425. else:
  426. columns = []
  427. for row in row_records:
  428. for key in row:
  429. if key not in columns:
  430. columns.append(key)
  431. if not columns:
  432. return
  433. tables.append(
  434. {
  435. "title": _build_table_title(current_title, record, rows_key),
  436. "columns": columns,
  437. "rows": row_records,
  438. }
  439. )
  440. visit(value, title, 0)
  441. return tables
  442. def _fallback_summary(observations: list[ToolObservation], tables: list[AgentResultTable]) -> str:
  443. ok_count = len([item for item in observations if item.get("ok")])
  444. failed = [item for item in observations if not item.get("ok")]
  445. lines = [f"已执行 {len(observations)} 个 MCP 工具,成功 {ok_count} 个。"]
  446. if tables:
  447. lines.append("结果已整理为下方表格。")
  448. if failed:
  449. lines.append("失败工具:" + ";".join(f"{item.get('name')}:{item.get('error') or '调用失败'}" for item in failed))
  450. evidence = _find_evidence(observations)
  451. if evidence:
  452. lines.append(f"依据:{evidence}")
  453. return "\n\n".join(lines)
  454. def _find_evidence(observations: list[ToolObservation]) -> str:
  455. for observation in observations:
  456. sql = _find_key_value(observation.get("result"), "executedSql", 0)
  457. if isinstance(sql, str) and sql.strip():
  458. return sql.strip()
  459. evidence = _find_key_value(observation.get("result"), "evidence", 0)
  460. if isinstance(evidence, str) and evidence.strip():
  461. return evidence.strip()
  462. return ""
  463. def _find_key_value(value: Any, key: str, depth: int) -> Any:
  464. if depth > 4 or not isinstance(value, dict):
  465. return None
  466. if key in value:
  467. return value[key]
  468. for child in value.values():
  469. result = _find_key_value(child, key, depth + 1)
  470. if result is not None:
  471. return result
  472. return None
  473. def _trace(phase: str, title: str, detail: str = "", data: Any = None) -> AgentTraceEvent:
  474. event: AgentTraceEvent = {
  475. "at": datetime.now(timezone.utc).isoformat(),
  476. "phase": phase, # type: ignore[typeddict-item]
  477. "title": title,
  478. }
  479. if detail:
  480. event["detail"] = detail
  481. if data is not None:
  482. event["data"] = data
  483. sink = _TRACE_SINK.get()
  484. if sink:
  485. try:
  486. sink(event)
  487. except Exception:
  488. pass
  489. return event
  490. def _append_trace(state: AgentState, *events: AgentTraceEvent) -> list[AgentTraceEvent]:
  491. return [*state.get("trace", []), *events]
  492. def _reason_text(record: dict[str, Any]) -> str:
  493. reason = record.get("reason")
  494. return reason.strip() if isinstance(reason, str) else ""
  495. def _summarize_result(value: Any) -> Any:
  496. record = _as_dict(value)
  497. if isinstance(record.get("rowCount"), int):
  498. return {"rowCount": record.get("rowCount"), "title": record.get("title"), "summary": record.get("summary")}
  499. if isinstance(record.get("rows"), list):
  500. return {"rows": len(record["rows"])}
  501. return value
  502. def _build_table_title(base_title: str, record: dict[str, Any], rows_key: str) -> str:
  503. title = str(record.get("title") or base_title).strip()
  504. count = f"({record['rowCount']} 行)" if "rowCount" in record else ""
  505. if rows_key == "comparisonRows":
  506. suffix = str(record.get("comparisonTitle") or "明细表").strip()
  507. return f"{title}:{suffix}"
  508. return f"{title}{count}"
  509. def _tool_title(name: str) -> str:
  510. titles = {
  511. "smqjh.config.get": "运行配置",
  512. "smqjh.cloud.health": "网关连通检查",
  513. "smqjh.schema.search": "业务表搜索结果",
  514. "smqjh.schema.getTable": "业务表说明",
  515. "smqjh.schema.businessRules": "业务规则",
  516. "smqjh.database.readonly.query": "数据库只读查询结果",
  517. "smqjh.database.smart.query": "智能数据库查询结果",
  518. "smqjh.order.count.query": "订单统计结果",
  519. "smqjh.product.lookup.summary": "商品资料查询结果",
  520. "smqjh.settlement.enterprise.list": "月结企业清单",
  521. "smqjh.settlement.monthly.plan": "企业月结计划",
  522. "smqjh.settlement.monthly.preview": "企业月结预览",
  523. }
  524. return titles.get(name, name)
  525. def _tool_signature(name: str, arguments: dict[str, Any]) -> str:
  526. return f"{name}:{json.dumps(arguments, ensure_ascii=False, sort_keys=True, default=str)}"
  527. def _as_dict(value: Any) -> dict[str, Any]:
  528. return value if isinstance(value, dict) else {}
  529. def _normalize_row(row: dict[str, Any]) -> dict[str, str]:
  530. return {str(key): "" if value is None else str(value) for key, value in row.items()}
  531. def _error_message(error: Exception) -> str:
  532. return str(error) or error.__class__.__name__