| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532 |
- from __future__ import annotations
- import json
- import time
- from datetime import datetime, timezone
- from typing import Any, Literal
- from langgraph.graph import END, START, StateGraph
- from .config import RuntimeConfig
- from .mcp_client import McpClient
- from .models import (
- AgentResultTable,
- AgentRunResponse,
- AgentState,
- AgentTraceEvent,
- ChatMessage,
- McpToolDescriptor,
- PlannedToolCall,
- ToolObservation,
- )
# Names of the graph nodes the post-planning router may hand control to.
Route = Literal[
    "execute_tools",
    "summarize",
    "chat",
]
class SmqjhAgentGraph:
    """LangGraph agent that answers a user message with MCP tools.

    Graph shape: ``load_tools -> plan -> (execute_tools -> plan)*`` followed by
    ``summarize`` when any tool observation exists, otherwise ``chat``.
    Planning, result summarization and plain chat are delegated to the MCP
    server's ``smqjh.ai.tool.plan`` / ``smqjh.ai.tool.summarize`` /
    ``smqjh.ai.chat`` tools. Every listed tool outside ``smqjh.ai.*`` is
    runnable by the agent.
    """

    # Maximum number of new tool calls accepted from one planning round.
    MAX_PLANNED_CALLS_PER_STEP = 3

    def __init__(self, config: RuntimeConfig) -> None:
        """Keep *config*, create the MCP client and compile the graph."""
        self.config = config
        self.mcp = McpClient(config.mcp.url, config.mcp.token)
        self.graph = self._build_graph()

    def health(self) -> Any:
        """Return the MCP client's ``status()`` result."""
        return self.mcp.status()

    def tools(self) -> list[McpToolDescriptor]:
        """Return the MCP tools this agent may execute (``smqjh.ai.*`` excluded)."""
        return self._load_runnable_tools()

    def run(self, message: str, history: list[ChatMessage] | None = None) -> AgentRunResponse:
        """Run the agent graph for one user *message*.

        Args:
            message: The user's request text.
            history: Prior chat messages; forwarded in every assistant request.

        Returns:
            The ``final`` response produced by the summarize/chat node. If the
            graph ends without one, returns a generic "no valid answer" response
            that still carries the observations, tables and trace.
        """
        state: AgentState = {
            "message": message,
            "history": history or [],
            "visitedSignatures": [],
            "steps": 0,
            "trace": [_trace("system", "收到用户请求", message)],
        }
        # The plan/execute loop needs ~2 supersteps per tool round. With a large
        # max_steps that exceeds LangGraph's default recursion limit (25), so
        # size the limit explicitly.
        result = self.graph.invoke(state, config={"recursion_limit": self._recursion_limit()})
        final = result.get("final")
        if final:
            return final
        observations = result.get("observations", [])
        return {
            "content": "这次没有生成有效回答,请稍后重试。",
            "model": "none",
            "usedMcp": False,
            "steps": int(result.get("steps", 0)),
            "toolCalls": observations,
            "tables": _extract_tables(observations),
            "trace": result.get("trace", []),
        }

    def _recursion_limit(self) -> int:
        """Return a LangGraph recursion limit that fits ``max_steps`` tool rounds.

        A full run visits:
        - load_tools once;
        - plan + execute_tools once per step;
        - a final plan;
        - summarize or chat.

        That is ``2 * max_steps + 3`` supersteps; a small margin is added.
        LangGraph's default of 25 is kept as the floor.
        """
        return max(25, 2 * int(self.config.max_steps) + 5)

    def _build_graph(self):
        """Wire the five agent nodes and their edges, and return the compiled graph."""
        workflow = StateGraph(AgentState)
        workflow.add_node("load_tools", self._load_tools_node)
        workflow.add_node("plan", self._plan_node)
        workflow.add_node("execute_tools", self._execute_tools_node)
        workflow.add_node("summarize", self._summarize_node)
        workflow.add_node("chat", self._chat_node)
        workflow.add_edge(START, "load_tools")
        workflow.add_edge("load_tools", "plan")
        workflow.add_conditional_edges(
            "plan",
            self._route_after_plan,
            {
                "execute_tools": "execute_tools",
                "summarize": "summarize",
                "chat": "chat",
            },
        )
        # Every execution round goes back to the planner.
        workflow.add_edge("execute_tools", "plan")
        workflow.add_edge("summarize", END)
        workflow.add_edge("chat", END)
        return workflow.compile()

    def _load_tools_node(self, state: AgentState) -> AgentState:
        """Load the runnable MCP tools.

        On failure the run continues with no tools, and the error is recorded
        in the trace.
        """
        try:
            tools = self._load_runnable_tools()
        except Exception as error:
            return {
                "tools": [],
                "trace": _append_trace(state, _trace("error", "MCP 工具加载失败", _error_message(error))),
            }
        return {
            "tools": tools,
            "trace": _append_trace(state, _trace("system", "MCP 工具加载完成", f"可用工具 {len(tools)} 个")),
        }

    def _plan_node(self, state: AgentState) -> AgentState:
        """Ask the MCP planner which tools to call next.

        No calls are planned once ``steps`` reaches ``config.max_steps``.
        Planner errors are recorded in the trace and also yield no calls, which
        routes the graph to summarize/chat.
        """
        max_steps = self.config.max_steps
        if int(state.get("steps", 0)) >= max_steps:
            return {
                "plannedToolCalls": [],
                "trace": _append_trace(state, _trace("plan", "达到最大工具步数", f"maxSteps={max_steps}")),
            }
        try:
            result = self.mcp.call_tool(
                "smqjh.ai.tool.plan",
                {
                    "request": self._assistant_request(state),
                    "tools": state.get("tools", []),
                    "observations": state.get("observations", []),
                },
            )
            record = _as_dict(result.get("structuredContent"))
            planned = self._select_new_calls(state, record.get("toolCalls"))
            return {
                "plannedToolCalls": planned,
                "trace": _append_trace(
                    state,
                    _trace(
                        "plan",
                        "DeepSeek 已规划工具调用" if planned else "DeepSeek 判断无需继续调用工具",
                        _reason_text(record),
                        {"toolCalls": planned},
                    ),
                ),
            }
        except Exception as error:
            return {
                "plannedToolCalls": [],
                "trace": _append_trace(state, _trace("error", "DeepSeek 工具规划失败", _error_message(error))),
            }

    def _select_new_calls(self, state: AgentState, raw_calls: Any) -> list[PlannedToolCall]:
        """Normalize the planner's ``toolCalls`` and keep only calls not yet executed.

        Signatures are computed on the *resolved* tool name, exactly as
        ``_execute_tools_node`` records them. This lets aliases such as
        ``product.lookup.summary`` dedupe correctly. Repeats within the same
        batch are also dropped. At most ``MAX_PLANNED_CALLS_PER_STEP`` calls
        are returned.
        """
        by_name = {tool.get("name", ""): tool for tool in state.get("tools", [])}
        seen = set(state.get("visitedSignatures", []))
        message = state.get("message", "")
        selected: list[PlannedToolCall] = []
        for item in raw_calls if isinstance(raw_calls, list) else []:
            if len(selected) >= self.MAX_PLANNED_CALLS_PER_STEP:
                break
            call = _normalize_tool_call(item, message)
            if not call:
                continue
            resolved = _resolve_tool_name(call["name"], by_name) or call["name"]
            signature = _tool_signature(resolved, call["arguments"])
            if signature in seen:
                continue
            seen.add(signature)
            selected.append(call)
        return selected

    def _execute_tools_node(self, state: AgentState) -> AgentState:
        """Execute the planned tool calls and record one observation per call.

        - Each call's signature (resolved tool name, else the requested one) is
          added to ``visitedSignatures`` so the planner will not repeat it.
        - Calls to unregistered tools, and errors raised by a tool call, become
          ``ok: False`` observations.
        - ``steps`` is incremented once per execution round.
        """
        by_name = {tool.get("name", ""): tool for tool in state.get("tools", [])}
        observations: list[ToolObservation] = []
        visited = set(state.get("visitedSignatures", []))
        for call in state.get("plannedToolCalls", []):
            requested_name = call["name"]
            arguments = call["arguments"]
            resolved_name = _resolve_tool_name(requested_name, by_name)
            visited.add(_tool_signature(resolved_name or requested_name, arguments))
            if not resolved_name:
                observations.append(
                    {
                        "name": requested_name,
                        "source": "mcp",
                        "arguments": arguments,
                        "ok": False,
                        "error": "DeepSeek 选择了未注册 MCP 工具",
                    }
                )
                continue
            started = time.perf_counter()
            try:
                result = self.mcp.call_tool(resolved_name, arguments)
                # Kept inside the try: a non-dict result must become an error
                # observation rather than crash the run.
                payload = result.get("structuredContent", result)
            except Exception as error:
                observations.append(
                    {
                        "name": resolved_name,
                        "source": "mcp",
                        "arguments": arguments,
                        "ok": False,
                        "error": _error_message(error),
                        "durationMs": int((time.perf_counter() - started) * 1000),
                    }
                )
                continue
            observations.append(
                {
                    "name": resolved_name,
                    "source": "mcp",
                    "arguments": arguments,
                    "ok": True,
                    "result": payload,
                    "durationMs": int((time.perf_counter() - started) * 1000),
                }
            )
        trace_events = [
            *state.get("trace", []),
            *(self._observation_trace(item) for item in observations),
        ]
        return {
            "observations": [*state.get("observations", []), *observations],
            "plannedToolCalls": [],
            # Sorted so the state is deterministic across runs (set order is not).
            "visitedSignatures": sorted(visited),
            "steps": int(state.get("steps", 0)) + 1,
            "trace": trace_events,
        }

    @staticmethod
    def _observation_trace(item: ToolObservation) -> AgentTraceEvent:
        """Build the trace event describing one tool observation."""
        ok = bool(item.get("ok"))
        return _trace(
            "tool" if ok else "error",
            f"工具执行完成:{item['name']}" if ok else f"工具执行失败:{item['name']}",
            f"{item.get('durationMs', 0)}ms" if ok else item.get("error", ""),
            {"arguments": item.get("arguments"), "result": _summarize_result(item.get("result"))},
        )

    def _summarize_node(self, state: AgentState) -> AgentState:
        """Produce the final answer from the collected tool observations.

        - The text comes from ``smqjh.ai.tool.summarize``.
        - If that text is blank, the locally built fallback summary is used.
        - If the call fails, the fallback is used with model ``"tool-only"``.

        Tables are always extracted locally from the observations.
        """
        observations = state.get("observations", [])
        tables = _extract_tables(observations)
        try:
            result = self.mcp.call_tool(
                "smqjh.ai.tool.summarize",
                {
                    "request": self._assistant_request(state),
                    "observations": observations,
                },
            )
            record = _as_dict(result.get("structuredContent"))
            content = str(record.get("content") or "").strip() or _fallback_summary(observations, tables)
            model = str(record.get("model") or "mcp-deepseek")
            event = _trace("summary", "结果总结完成", f"{len(tables)} 个表格")
        except Exception as error:
            content = _fallback_summary(observations, tables)
            model = "tool-only"
            event = _trace("error", "结果总结失败,已使用工具结果兜底", _error_message(error))
        return self._final_update(
            state,
            event,
            content=content,
            model=model,
            used_mcp=True,
            tool_calls=observations,
            tables=tables,
        )

    def _chat_node(self, state: AgentState) -> AgentState:
        """Answer without tools via ``smqjh.ai.chat``.

        A failure yields an apology message with ``usedMcp=False``.
        """
        try:
            result = self.mcp.call_tool("smqjh.ai.chat", {"request": self._assistant_request(state)})
            record = _as_dict(result.get("structuredContent"))
            content = str(record.get("content") or "").strip() or "DeepSeek 没有返回有效内容。"
            model = str(record.get("model") or "mcp-deepseek")
            used_mcp = True
            event = _trace("chat", "普通对话完成", model)
        except Exception as error:
            reason = _error_message(error)
            content = f"暂时无法完成回答:{reason}"
            model = "none"
            used_mcp = False
            event = _trace("error", "普通对话失败", reason)
        return self._final_update(
            state,
            event,
            content=content,
            model=model,
            used_mcp=used_mcp,
            tool_calls=[],
            tables=[],
        )

    @staticmethod
    def _final_update(
        state: AgentState,
        event: AgentTraceEvent,
        *,
        content: str,
        model: str,
        used_mcp: bool,
        tool_calls: list[ToolObservation],
        tables: list[AgentResultTable],
    ) -> AgentState:
        """Append *event* to the trace and return the terminal ``final`` state update."""
        trace_events = _append_trace(state, event)
        return {
            "final": {
                "content": content,
                "model": model,
                "usedMcp": used_mcp,
                "steps": int(state.get("steps", 0)),
                "toolCalls": tool_calls,
                "tables": tables,
                "trace": trace_events,
            },
            "trace": trace_events,
        }

    def _route_after_plan(self, state: AgentState) -> Route:
        """Choose the next node after planning.

        - Execute when calls are planned and step budget remains.
        - Otherwise summarize if any tool was observed.
        - Otherwise chat.
        """
        if state.get("plannedToolCalls") and int(state.get("steps", 0)) < self.config.max_steps:
            return "execute_tools"
        if state.get("observations"):
            return "summarize"
        return "chat"

    def _assistant_request(self, state: AgentState) -> dict[str, Any]:
        """Build the ``request`` payload sent to the ``smqjh.ai.*`` assistant tools."""
        return {
            "message": state.get("message", ""),
            "history": state.get("history", []),
            "environmentName": self.config.environment_name,
            "baseUrl": self.config.base_url,
            "authenticated": True,
            "username": "web-agent",
        }

    def _load_runnable_tools(self) -> list[McpToolDescriptor]:
        """List MCP tools, excluding the ``smqjh.ai.*`` assistant tools used internally."""
        return [tool for tool in self.mcp.list_tools() if not str(tool.get("name", "")).startswith("smqjh.ai.")]
def _normalize_tool_call(value: Any, message: str) -> PlannedToolCall | None:
    """Coerce one raw planner entry into a ``PlannedToolCall``.

    Returns ``None`` when the entry is not a dict or has no non-blank
    ``name``. Otherwise the entry's ``arguments`` (anything other than a dict
    becomes ``{}``) are passed through ``_repair_tool_arguments`` together
    with the user *message*.
    """
    entry = value if isinstance(value, dict) else {}
    tool_name = str(entry.get("name") or "").strip()
    if not tool_name:
        return None
    raw_arguments = entry.get("arguments")
    arguments = raw_arguments if isinstance(raw_arguments, dict) else {}
    return {
        "name": tool_name,
        "arguments": _repair_tool_arguments(tool_name, arguments, message),
    }
- def _repair_tool_arguments(name: str, arguments: dict[str, Any], message: str) -> dict[str, Any]:
- if name == "smqjh.database.smart.query" and not isinstance(arguments.get("question"), str):
- return {**arguments, "question": message}
- if name == "smqjh.product.lookup.summary":
- current = str(arguments.get("productKeyword") or "").strip()
- return {**arguments, "productKeyword": current or _extract_product_keyword(message)}
- return arguments
- def _extract_product_keyword(message: str) -> str:
- text = message
- for token in [
- "帮我",
- "麻烦",
- "查询一下",
- "查一下",
- "查询",
- "查看",
- "当前",
- "业务系统",
- "系统里面",
- "系统里",
- "后台",
- "我方",
- "我们的",
- "商品库",
- "商品表",
- "商品描述是什么",
- "商品描述",
- "描述是什么",
- "描述",
- "价格是多少",
- "价格",
- "定价",
- "是多少",
- "是什么",
- "呢",
- ]:
- text = text.replace(token, " ")
- return " ".join(text.replace(",", " ").replace("。", " ").replace("?", " ").replace("?", " ").split())
- def _resolve_tool_name(name: str, by_name: dict[str, McpToolDescriptor]) -> str | None:
- candidates = [
- name,
- name.removeprefix("smqjh.") if name.startswith("smqjh.") else f"smqjh.{name}",
- "smqjh.product.lookup.summary" if name == "product.lookup.summary" else "",
- "smqjh.order.count.query" if name == "order.count.query" else "",
- "smqjh.database.smart.query" if name == "database.smart.query" else "",
- "smqjh.database.readonly.query" if name == "database.readonly.query" else "",
- "smqjh.cloud.health" if name == "cloud.health" else "",
- ]
- return next((candidate for candidate in candidates if candidate and candidate in by_name), None)
- def _extract_tables(observations: list[ToolObservation]) -> list[AgentResultTable]:
- tables: list[AgentResultTable] = []
- for observation in observations:
- if not observation.get("ok"):
- continue
- tables.extend(_extract_tables_from_value(_tool_title(observation.get("name", "")), observation.get("result")))
- return tables[:6]
- def _extract_tables_from_value(title: str, value: Any) -> list[AgentResultTable]:
- tables: list[AgentResultTable] = []
- seen: set[int] = set()
- def visit(node: Any, current_title: str, depth: int) -> None:
- if depth > 5 or not isinstance(node, dict):
- return
- node_id = id(node)
- if node_id in seen:
- return
- seen.add(node_id)
- add_rows_table(current_title, node, "rows")
- add_rows_table(current_title, node, "comparisonRows")
- for key in ("data", "structuredContent", "result", "record"):
- if key in node:
- visit(node[key], current_title, depth + 1)
- def add_rows_table(current_title: str, record: dict[str, Any], rows_key: str) -> None:
- rows = record.get(rows_key)
- if not isinstance(rows, list) or not all(isinstance(row, dict) for row in rows):
- return
- row_records = [_normalize_row(row) for row in rows[:200]]
- provided_columns = record.get("columns")
- if isinstance(provided_columns, list) and provided_columns:
- columns = [str(column) for column in provided_columns]
- else:
- columns = []
- for row in row_records:
- for key in row:
- if key not in columns:
- columns.append(key)
- if not columns:
- return
- tables.append(
- {
- "title": _build_table_title(current_title, record, rows_key),
- "columns": columns,
- "rows": row_records,
- }
- )
- visit(value, title, 0)
- return tables
def _fallback_summary(observations: list[ToolObservation], tables: list[AgentResultTable]) -> str:
    """Build a plain-text summary of tool results when no model summary is available.

    Paragraphs, separated by blank lines:
    1. The executed and succeeded tool counts.
    2. A pointer to the tables, when any exist.
    3. The failed tools with their errors, when any failed.
    4. The first SQL/evidence string found in the results, when one exists.
    """
    succeeded = sum(1 for item in observations if item.get("ok"))
    paragraphs = [f"已执行 {len(observations)} 个 MCP 工具,成功 {succeeded} 个。"]
    if tables:
        paragraphs.append("结果已整理为下方表格。")
    failed = [item for item in observations if not item.get("ok")]
    if failed:
        details = ";".join(f"{item.get('name')}:{item.get('error') or '调用失败'}" for item in failed)
        paragraphs.append(f"失败工具:{details}")
    evidence = _find_evidence(observations)
    if evidence:
        paragraphs.append(f"依据:{evidence}")
    return "\n\n".join(paragraphs)
- def _find_evidence(observations: list[ToolObservation]) -> str:
- for observation in observations:
- sql = _find_key_value(observation.get("result"), "executedSql", 0)
- if isinstance(sql, str) and sql.strip():
- return sql.strip()
- evidence = _find_key_value(observation.get("result"), "evidence", 0)
- if isinstance(evidence, str) and evidence.strip():
- return evidence.strip()
- return ""
- def _find_key_value(value: Any, key: str, depth: int) -> Any:
- if depth > 4 or not isinstance(value, dict):
- return None
- if key in value:
- return value[key]
- for child in value.values():
- result = _find_key_value(child, key, depth + 1)
- if result is not None:
- return result
- return None
- def _trace(phase: str, title: str, detail: str = "", data: Any = None) -> AgentTraceEvent:
- event: AgentTraceEvent = {
- "at": datetime.now(timezone.utc).isoformat(),
- "phase": phase, # type: ignore[typeddict-item]
- "title": title,
- }
- if detail:
- event["detail"] = detail
- if data is not None:
- event["data"] = data
- return event
- def _append_trace(state: AgentState, event: AgentTraceEvent) -> list[AgentTraceEvent]:
- return [*state.get("trace", []), event]
- def _reason_text(record: dict[str, Any]) -> str:
- reason = record.get("reason")
- return reason.strip() if isinstance(reason, str) else ""
def _summarize_result(value: Any) -> Any:
    """Condense a tool result for the trace.

    - A result with an ``int`` ``rowCount`` becomes
      ``{"rowCount", "title", "summary"}``.
    - A result with a ``rows`` list becomes ``{"rows": <row count>}``.
    - Anything else is returned unchanged.
    """
    record = _as_dict(value)
    row_count = record.get("rowCount")
    if isinstance(row_count, int):
        return {"rowCount": row_count, "title": record.get("title"), "summary": record.get("summary")}
    rows = record.get("rows")
    if isinstance(rows, list):
        return {"rows": len(rows)}
    return value
- def _build_table_title(base_title: str, record: dict[str, Any], rows_key: str) -> str:
- title = str(record.get("title") or base_title).strip()
- count = f"({record['rowCount']} 行)" if "rowCount" in record else ""
- return f"{title}:价格对比" if rows_key == "comparisonRows" else f"{title}{count}"
- def _tool_title(name: str) -> str:
- titles = {
- "smqjh.config.get": "运行配置",
- "smqjh.cloud.health": "网关连通检查",
- "smqjh.schema.search": "业务表搜索结果",
- "smqjh.schema.getTable": "业务表说明",
- "smqjh.schema.businessRules": "业务规则",
- "smqjh.database.readonly.query": "数据库只读查询结果",
- "smqjh.database.smart.query": "智能数据库查询结果",
- "smqjh.order.count.query": "订单统计结果",
- "smqjh.product.lookup.summary": "商品资料查询结果",
- "smqjh.settlement.enterprise.list": "月结企业清单",
- "smqjh.settlement.monthly.plan": "企业月结计划",
- }
- return titles.get(name, name)
- def _tool_signature(name: str, arguments: dict[str, Any]) -> str:
- return f"{name}:{json.dumps(arguments, ensure_ascii=False, sort_keys=True, default=str)}"
- def _as_dict(value: Any) -> dict[str, Any]:
- return value if isinstance(value, dict) else {}
- def _normalize_row(row: dict[str, Any]) -> dict[str, str]:
- return {str(key): "" if value is None else str(value) for key, value in row.items()}
- def _error_message(error: Exception) -> str:
- return str(error) or error.__class__.__name__
|