Quellcode durchsuchen

improve topo tools description

Guangyu vor 3 Wochen
Ursprung
Commit
522b113edf
3 geänderte Dateien mit 80 neuen und 12 gelöschten Zeilen
  1. 19 1
      README.md
  2. 24 6
      instrument_config_mcp/server.py
  3. 37 5
      instrument_config_mcp/topology_cache.py

+ 19 - 1
README.md

@@ -403,11 +403,19 @@ Schema 说明:
 - `root_shape`
 - `refreshed_at`
 
-### `topology_get_node(project_key, topology_id, node_id, include_siblings=True, include_children=True)`
+### `topology_get_node(project_key, topology_id, node_id='root', include_siblings=True, include_children=True)`
 
 用途:
 
 - 获取一个节点及其直接邻域
+- 如果 `node_id='root'`,会自动解析该拓扑的根节点
+
+输入约束:
+
+- `node_id` 必须是该拓扑缓存中的真实 `node_id`
+- `node_id` 不是节点名称
+- `node_id='root'` 是便捷模式,不要求缓存里真的存在名为 `root` 的节点
+- 如果不知道真实 `node_id`,可先用 `topology.find_context` 从命中结果里的 `self.node_id` 获取
 
 输出结构:
 
@@ -423,6 +431,12 @@ Schema 说明:
 
 - 按实体快速反查命中的拓扑节点上下文
 
+输入约束:
+
+- `entity_type` 只允许 `meter` 或 `device`
+- 不接受中文值,如 `仪表`、`设备`
+- 不接受复数或其他业务类型,如 `meters`、`devices`、`point`、`system`
+
 输出结构:
 
 - `query`
@@ -460,6 +474,10 @@ Schema 说明:
 - 直接子节点
 - 同父兄弟节点
 
+其中:
+
+- 当 `node_id='root'` 时,会先自动解析当前拓扑的根节点,再返回该根节点的直接邻域
+
 它不做深层遍历。
 
 ### 实体上下文查询

+ 24 - 6
instrument_config_mcp/server.py

@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import os
+from typing import Annotated
 from typing import Any
 
 from fastmcp import FastMCP
+from pydantic import Field
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
 
@@ -227,7 +229,7 @@ def search_points(
 
 @mcp.tool(name="topology.group_list")
 def topology_group_list(project_key: str) -> dict[str, Any]:
-    """List cached topology groups as a tree."""
+    """List topology groups to help locate device and meter relationships."""
     return list_topology_groups(project_key)
 
 
@@ -235,7 +237,7 @@ def topology_group_list(project_key: str) -> dict[str, Any]:
 def topology_list(
     project_key: str, group_id: int | None = None, object_type_code: int | None = None
 ) -> dict[str, Any]:
-    """List cached topologies, optionally filtered by group or object type."""
+    """List topologies for checking device and meter hierarchy relationships."""
     return list_topologies(
         project_key, group_id=group_id, object_type_code=object_type_code
     )
@@ -245,11 +247,19 @@ def topology_list(
 def topology_get_node(
     project_key: str,
     topology_id: int,
-    node_id: str,
+    node_id: Annotated[
+        str,
+        Field(
+            description=(
+                "Node ID in the target topology. Use 'root' to resolve the "
+                "topology root node automatically."
+            )
+        ),
+    ] = "root",
     include_siblings: bool = True,
     include_children: bool = True,
 ) -> dict[str, Any]:
-    """Get one cached topology node with its immediate neighborhood."""
+    """Get one topology node and its direct parent-child relationships."""
     return get_topology_node(
         project_key,
         topology_id,
@@ -262,14 +272,22 @@ def topology_get_node(
 @mcp.tool(name="topology.find_context")
 def topology_find_context(
     project_key: str,
-    entity_type: str,
+    entity_type: Annotated[
+        str,
+        Field(
+            description=(
+                "Entity type for topology relationship lookup. Only 'meter' or "
+                "'device' are accepted."
+            )
+        ),
+    ],
     entity_id: int,
     topology_id: int | None = None,
     include_siblings: bool = True,
     ancestor_depth: int = 5,
     descendant_depth: int = 2,
 ) -> dict[str, Any]:
-    """Find cached topology context for a device or meter."""
+    """Find topology context for a meter or device, including hierarchy relationships."""
     return find_topology_context(
         project_key,
         entity_type,

+ 37 - 5
instrument_config_mcp/topology_cache.py

@@ -901,6 +901,33 @@ def _get_node_or_error(
     return node_row
 
 
+def _resolve_root_node_id(
+    node_map: dict[str, TopologyNode], parents_by_node: dict[str, list[str]]
+) -> str:
+    root_ids = [node_id for node_id in node_map if not parents_by_node.get(node_id)]
+    if not root_ids:
+        raise ValueError("root node not found in cache")
+    root_ids.sort(
+        key=lambda item: (
+            node_map[item].level if node_map[item].level is not None else 10**9,
+            node_map[item].path_text or node_map[item].node_name,
+            item,
+        )
+    )
+    return root_ids[0]
+
+
+def _normalize_requested_node_id(
+    raw_node_id: str,
+    node_map: dict[str, TopologyNode],
+    parents_by_node: dict[str, list[str]],
+) -> str:
+    normalized = _text(raw_node_id)
+    if normalized.lower() != "root":
+        return normalized
+    return _resolve_root_node_id(node_map, parents_by_node)
+
+
 def _node_payload(node_row: TopologyNode) -> dict[str, Any]:
     return {
         "node_id": node_row.node_id,
@@ -1066,7 +1093,7 @@ def _topology_metadata_payload(
 def get_topology_node(
     project_key: str,
     topology_id: int,
-    node_id: str,
+    node_id: str = "root",
     *,
     include_siblings: bool = True,
     include_children: bool = True,
@@ -1084,20 +1111,25 @@ def get_topology_node(
         group_rows = _load_group_rows(session, project_key)
         group_path_map = _group_path_map(group_rows)
         registry_row = _get_registry_or_error(session, project_key, topology_id)
-        node_row = _get_node_or_error(session, project_key, topology_id, node_id)
         node_map = _load_node_map(session, project_key, topology_id)
         parents_by_node, children_by_node = _load_adjacency(
             session, project_key, topology_id
         )
+        resolved_node_id = _normalize_requested_node_id(
+            node_id, node_map, parents_by_node
+        )
+        node_row = _get_node_or_error(
+            session, project_key, topology_id, resolved_node_id
+        )
 
-    parent_ids = parents_by_node.get(node_id, [])
-    child_ids = children_by_node.get(node_id, []) if include_children else []
+    parent_ids = parents_by_node.get(resolved_node_id, [])
+    child_ids = children_by_node.get(resolved_node_id, []) if include_children else []
     sibling_ids: list[str] = []
     if include_siblings and len(parent_ids) == 1:
         sibling_ids = [
             candidate
             for candidate in children_by_node.get(parent_ids[0], [])
-            if candidate != node_id
+            if candidate != resolved_node_id
         ]
 
     return {