"""SQL cache of upstream topology configuration.

``refresh_topology_cache`` pulls the topology list and per-topology details
from the upstream config API and stores them in five tables (groups,
topology registry, nodes, edges, and an entity -> node index). The read
helpers below (``list_topology_groups``, ``list_topologies``,
``get_topology_node``, ``find_topology_context``) answer only from those
tables.
"""

from __future__ import annotations

from collections import defaultdict, deque
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import Index, Integer, String, Text, UniqueConstraint, delete, select
from sqlalchemy.orm import Mapped, Session, mapped_column

from .config_api import list_topologies_with_group as api_list_topologies_with_group
from .config_api import get_topology as api_get_topology
from .db import Base, sql_engine


# ---------------------------------------------------------------------------
# Small value coercion helpers
# ---------------------------------------------------------------------------


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _safe_int(raw_value: Any) -> int | None:
    """Convert *raw_value* to ``int``, returning ``None`` when that is impossible.

    The value is stringified and stripped first, so ``" 42 "`` parses while
    ``"4.2"`` or ``"abc"`` yield ``None``.
    """
    if raw_value is None:
        return None
    try:
        return int(str(raw_value).strip())
    except (TypeError, ValueError):
        return None


def _text(raw_value: Any) -> str:
    """Return *raw_value* as a stripped string; falsy values (incl. ``0``) give ``""``."""
    return str(raw_value or "").strip()


def _bool_as_int(raw_value: Any) -> int:
    """Map the truthiness of *raw_value* to ``1`` / ``0`` for integer columns."""
    return 1 if bool(raw_value) else 0


def _normalize_entity_type(entity_type: str) -> str:
    """Lower-case and validate an entity type.

    Raises:
        ValueError: if the value is not ``"meter"`` or ``"device"``.
    """
    normalized = _text(entity_type).lower()
    if normalized not in {"meter", "device"}:
        raise ValueError("entity_type must be 'meter' or 'device'")
    return normalized


# ---------------------------------------------------------------------------
# ORM models
# ---------------------------------------------------------------------------


class TopologyGroup(Base):
    """One topology group (folder) from the upstream topology list."""

    __tablename__ = "topology_group"
    __table_args__ = (
        UniqueConstraint(
            "project_key", "group_id", name="uq_topology_group_project_group"
        ),
        Index("ix_topology_group_project_parent", "project_key", "parent_group_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    project_key: Mapped[str] = mapped_column(String(128), nullable=False)
    group_id: Mapped[int] = mapped_column(Integer, nullable=False)
    group_name: Mapped[str] = mapped_column(String(255), nullable=False)
    parent_group_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Names of the enclosing groups down to this one, joined with " / ".
    group_path_text: Mapped[str] = mapped_column(Text, nullable=False)
    # 1-based nesting depth: top-level groups have level 1.
    level: Mapped[int] = mapped_column(Integer, nullable=False)
    # 1-based position among its siblings in the upstream list.
    sort_index: Mapped[int] = mapped_column(Integer, nullable=False)
    # ISO-8601 UTC timestamp of the refresh that wrote the row.
    refreshed_at: Mapped[str] = mapped_column(String(64), nullable=False)
    is_active: Mapped[int] = mapped_column(Integer, nullable=False, default=1)


class TopologyRegistry(Base):
    """One cached topology with its metadata."""

    __tablename__ = "topology_registry"
    __table_args__ = (
        UniqueConstraint(
            "project_key", "topology_id", name="uq_topology_registry_project_topology"
        ),
        Index("ix_topology_registry_project_group", "project_key", "group_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    project_key: Mapped[str] = mapped_column(String(128), nullable=False)
    topology_id: Mapped[int] = mapped_column(Integer, nullable=False)
    topology_name: Mapped[str] = mapped_column(String(255), nullable=False)
    topology_type: Mapped[int] = mapped_column(Integer, nullable=False)
    object_type_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
    group_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # "tree" when the upstream diagram is a list, "graph" when it is a dict.
    root_shape: Mapped[str] = mapped_column(String(32), nullable=False)
    source_updated_time: Mapped[str] = mapped_column(
        String(64), nullable=False, default=""
    )
    refreshed_at: Mapped[str] = mapped_column(String(64), nullable=False)
    is_active: Mapped[int] = mapped_column(Integer, nullable=False, default=1)


class TopologyNode(Base):
    """One node of a cached topology."""

    __tablename__ = "topology_node"
    __table_args__ = (
        UniqueConstraint(
            "project_key",
            "topology_id",
            "node_id",
            name="uq_topology_node_project_topology_node",
        ),
        Index(
            "ix_topology_node_project_topology_parent",
            "project_key",
            "topology_id",
            "parent_node_id",
        ),
        Index(
            "ix_topology_node_project_topology_refer",
            "project_key",
            "topology_id",
            "refer_id",
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    project_key: Mapped[str] = mapped_column(String(128), nullable=False)
    topology_id: Mapped[int] = mapped_column(Integer, nullable=False)
    node_id: Mapped[str] = mapped_column(Text, nullable=False)
    node_name: Mapped[str] = mapped_column(Text, nullable=False)
    parent_node_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    level: Mapped[int | None] = mapped_column(Integer, nullable=True)
    node_type_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
    refer_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    refer_level: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_virtual: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Node names from a root down to this node, joined with " / ".
    path_text: Mapped[str] = mapped_column(Text, nullable=False, default="")
    child_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    sort_index: Mapped[int | None] = mapped_column(Integer, nullable=True)


class TopologyEdge(Base):
    """A directed source -> target edge between two nodes of a topology."""

    __tablename__ = "topology_edge"
    __table_args__ = (
        UniqueConstraint(
            "project_key",
            "topology_id",
            "source_node_id",
            "target_node_id",
            name="uq_topology_edge_project_topology_nodes",
        ),
        Index(
            "ix_topology_edge_project_topology_source",
            "project_key",
            "topology_id",
            "source_node_id",
        ),
        Index(
            "ix_topology_edge_project_topology_target",
            "project_key",
            "topology_id",
            "target_node_id",
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    project_key: Mapped[str] = mapped_column(String(128), nullable=False)
    topology_id: Mapped[int] = mapped_column(Integer, nullable=False)
    source_node_id: Mapped[str] = mapped_column(Text, nullable=False)
    target_node_id: Mapped[str] = mapped_column(Text, nullable=False)
    sort_index: Mapped[int] = mapped_column(Integer, nullable=False, default=0)


class TopologyEntityIndex(Base):
    """Maps a meter/device id to the deepest node(s) listing it in a topology."""

    __tablename__ = "topology_entity_index"
    __table_args__ = (
        UniqueConstraint(
            "project_key",
            "entity_type",
            "entity_id",
            "topology_id",
            "node_id",
            name="uq_topology_entity_index_project_entity_topology_node",
        ),
        Index(
            "ix_topology_entity_index_project_entity",
            "project_key",
            "entity_type",
            "entity_id",
        ),
        Index(
            "ix_topology_entity_index_project_topology_node",
            "project_key",
            "topology_id",
            "node_id",
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    project_key: Mapped[str] = mapped_column(String(128), nullable=False)
    # "meter" or "device" (see _normalize_entity_type).
    entity_type: Mapped[str] = mapped_column(String(16), nullable=False)
    entity_id: Mapped[int] = mapped_column(Integer, nullable=False)
    topology_id: Mapped[int] = mapped_column(Integer, nullable=False)
    node_id: Mapped[str] = mapped_column(Text, nullable=False)
    # Level of the node the entity was attached to.
    depth: Mapped[int | None] = mapped_column(Integer, nullable=True)


def ensure_topology_cache_tables() -> None:
    """Create any missing tables registered on ``Base.metadata``."""
    Base.metadata.create_all(sql_engine())


# ---------------------------------------------------------------------------
# Upstream payload parsing
# ---------------------------------------------------------------------------


def _collect_group_and_topology_refs(
    project_key: str,
    items: list[Any],
    *,
    refreshed_at: str,
    parent_group_id: int | None = None,
    parent_group_path: tuple[str, ...] = (),
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Flatten the nested upstream topology list into group rows and topology refs.

    Items with ``type == 1`` are groups and are recursed into through their
    ``children``. Every other item with an integer ``id`` is treated as a
    topology belonging to the enclosing group. Non-dict items and items
    without a usable id are skipped.

    Returns:
        ``(group_rows, topology_refs)``. The group rows are ready for
        ``TopologyGroup(**row)``; each topology ref carries ``topology_id``,
        ``topology_name``, ``group_id`` and ``sort_index``.
    """
    group_rows: list[dict[str, Any]] = []
    topology_refs: list[dict[str, Any]] = []
    for sort_index, item in enumerate(items, start=1):
        if not isinstance(item, dict):
            continue
        item_id = _safe_int(item.get("id"))
        if item_id is None:
            continue
        item_name = _text(item.get("name")) or str(item_id or "")
        item_type = _safe_int(item.get("type"))
        raw_children = item.get("children")
        children = raw_children if isinstance(raw_children, list) else []

        if item_type == 1:
            group_path = (*parent_group_path, item_name)
            group_rows.append(
                {
                    "project_key": project_key,
                    "group_id": item_id,
                    "group_name": item_name,
                    "parent_group_id": parent_group_id,
                    "group_path_text": " / ".join(group_path),
                    "level": len(group_path),
                    "sort_index": sort_index,
                    "refreshed_at": refreshed_at,
                    "is_active": 1,
                }
            )
            nested_group_rows, nested_topology_refs = _collect_group_and_topology_refs(
                project_key,
                children,
                refreshed_at=refreshed_at,
                parent_group_id=item_id,
                parent_group_path=group_path,
            )
            group_rows.extend(nested_group_rows)
            topology_refs.extend(nested_topology_refs)
            continue

        topology_refs.append(
            {
                "topology_id": item_id,
                "topology_name": item_name,
                "group_id": parent_group_id,
                "sort_index": sort_index,
            }
        )
    return group_rows, topology_refs


def _as_int_list(raw_value: Any) -> list[int]:
    """Return the integer-convertible members of *raw_value*, or ``[]`` if it is not a list."""
    if not isinstance(raw_value, list):
        return []
    return [value for value in map(_safe_int, raw_value) if value is not None]


def _record_deepest_entities(
    entity_best: dict[tuple[str, int], tuple[int, list[str]]],
    entity_node_ids: dict[tuple[str, int], list[str]],
    entity_type: str,
    entity_ids: list[int],
    *,
    depth: int,
    node_id: str,
) -> None:
    """Track, per entity, the node(s) at which it appears with the greatest depth.

    A strictly deeper occurrence replaces what was recorded before. An
    occurrence at the same depth appends *node_id* to ``entity_node_ids`` if it
    is not already there. Both dicts are mutated in place.
    """
    for entity_id in entity_ids:
        key = (entity_type, entity_id)
        existing = entity_best.get(key)
        if existing is None or depth > existing[0]:
            entity_best[key] = (depth, [node_id])
            entity_node_ids[key] = [node_id]
            continue
        if depth == existing[0] and node_id not in entity_node_ids[key]:
            entity_node_ids[key].append(node_id)


def _build_entity_rows(
    project_key: str,
    topology_id: int,
    entity_best: dict[tuple[str, int], tuple[int, list[str]]],
    entity_node_ids: dict[tuple[str, int], list[str]],
) -> list[dict[str, Any]]:
    """Expand the deepest-entity bookkeeping into ``TopologyEntityIndex`` rows."""
    return [
        {
            "project_key": project_key,
            "entity_type": entity_type,
            "entity_id": entity_id,
            "topology_id": topology_id,
            "node_id": node_id,
            "depth": depth,
        }
        for (entity_type, entity_id), (depth, _) in entity_best.items()
        for node_id in entity_node_ids[(entity_type, entity_id)]
    ]


def _parse_tree_topology(
    project_key: str,
    topology_id: int,
    diagram: list[Any],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]]:
    """Parse a tree-shaped diagram (a list of root nodes with nested ``children``).

    Nodes are visited depth-first in pre-order. Each node's level is its
    declared ``level`` or, failing that, its 1-based nesting depth. A
    parent -> child edge is emitted for every node that has a parent, taken
    either from the nesting or from the node's own ``parent_id`` field.

    Returns:
        ``(node_rows, edge_rows, entity_rows)``.
    """
    node_rows: list[dict[str, Any]] = []
    edge_rows: list[dict[str, Any]] = []
    entity_best: dict[tuple[str, int], tuple[int, list[str]]] = {}
    entity_node_ids: dict[tuple[str, int], list[str]] = {}

    def visit(
        node: dict[str, Any],
        parent_node_id: str | None,
        path_names: tuple[str, ...],
        sort_index: int,
    ) -> None:
        node_id = _text(node.get("id"))
        if not node_id:
            return
        node_name = _text(node.get("name")) or node_id
        raw_children = node.get("children")
        children = raw_children if isinstance(raw_children, list) else []
        level = _safe_int(node.get("level")) or (len(path_names) + 1)
        path_text = " / ".join((*path_names, node_name))
        # Roots may still declare an explicit parent via "parent_id".
        effective_parent_node_id = parent_node_id or (
            _text(node.get("parent_id")) or None
        )
        node_rows.append(
            {
                "project_key": project_key,
                "topology_id": topology_id,
                "node_id": node_id,
                "node_name": node_name,
                "parent_node_id": effective_parent_node_id,
                "level": level,
                "node_type_code": _safe_int(node.get("type")),
                "refer_id": _safe_int(node.get("refer_id")),
                "refer_level": _safe_int(node.get("refer_level")),
                "is_virtual": _bool_as_int(node.get("is_virtual")),
                "path_text": path_text,
                "child_count": len(children),
                "sort_index": sort_index,
            }
        )
        if effective_parent_node_id:
            edge_rows.append(
                {
                    "project_key": project_key,
                    "topology_id": topology_id,
                    "source_node_id": effective_parent_node_id,
                    "target_node_id": node_id,
                    "sort_index": sort_index,
                }
            )
        _record_deepest_entities(
            entity_best,
            entity_node_ids,
            "meter",
            _as_int_list(node.get("meter_list")),
            depth=level,
            node_id=node_id,
        )
        _record_deepest_entities(
            entity_best,
            entity_node_ids,
            "device",
            _as_int_list(node.get("device_list")),
            depth=level,
            node_id=node_id,
        )
        for child_sort_index, child in enumerate(children, start=1):
            if isinstance(child, dict):
                visit(child, node_id, (*path_names, node_name), child_sort_index)

    for root_sort_index, root in enumerate(diagram, start=1):
        if isinstance(root, dict):
            visit(root, None, (), root_sort_index)

    entity_rows = _build_entity_rows(
        project_key, topology_id, entity_best, entity_node_ids
    )
    return node_rows, edge_rows, entity_rows


def _build_graph_node_context(
    nodes: list[dict[str, Any]],
    edges: list[dict[str, Any]],
) -> tuple[dict[str, int], dict[str, str | None], dict[str, str], dict[str, int]]:
    """Derive level, parent, path text and child count for a graph-shaped topology.

    Roots are nodes with no incoming edge. If every node has one (a pure
    cycle), the first node is used as the root. Levels are breadth-first
    (shortest-path) hop counts from the roots, starting at 1. Each node's
    parent is the node that first discovered it, and its path text follows
    that same shortest path. Nodes unreachable from any root fall back to
    their declared ``level`` (or 1) and their own name as path. Duplicate
    edges are counted once.

    Returns:
        ``(level_map, parent_map, path_map, child_count_map)`` keyed by node id.
    """
    incoming: dict[str, list[str]] = defaultdict(list)
    outgoing: dict[str, list[str]] = defaultdict(list)
    seen_edges: set[tuple[str, str]] = set()
    for edge in edges:
        source_node_id = _text(edge.get("source"))
        target_node_id = _text(edge.get("target"))
        if not source_node_id or not target_node_id:
            continue
        if (source_node_id, target_node_id) in seen_edges:
            continue
        seen_edges.add((source_node_id, target_node_id))
        outgoing[source_node_id].append(target_node_id)
        incoming[target_node_id].append(source_node_id)

    node_ids = [node_id for node_id in (_text(node.get("id")) for node in nodes) if node_id]
    roots = [node_id for node_id in node_ids if not incoming.get(node_id)]
    if not roots:
        roots = node_ids[:1]

    level_map: dict[str, int] = {}
    path_map: dict[str, str] = {}
    parent_map: dict[str, str | None] = {node_id: None for node_id in node_ids}
    child_count_map: dict[str, int] = {
        node_id: len(outgoing.get(node_id, [])) for node_id in node_ids
    }
    node_name_map = {
        _text(node.get("id")): _text(node.get("name")) or _text(node.get("id"))
        for node in nodes
    }

    queue: deque[tuple[str, int, str]] = deque(
        (root_node_id, 1, node_name_map.get(root_node_id) or root_node_id)
        for root_node_id in roots
    )
    while queue:
        node_id, depth, path_text = queue.popleft()
        # BFS pops each node first at its smallest depth. Any later pop is a
        # longer path or a cycle, so skipping it keeps levels consistent with
        # parent_map and guarantees termination on cyclic graphs.
        if node_id in level_map:
            continue
        level_map[node_id] = depth
        path_map[node_id] = path_text
        for child_node_id in outgoing.get(node_id, []):
            if parent_map.get(child_node_id) is None:
                parent_map[child_node_id] = node_id
            if child_node_id in level_map:
                continue
            child_name = node_name_map.get(child_node_id) or child_node_id
            queue.append((child_node_id, depth + 1, f"{path_text} / {child_name}"))

    for node in nodes:
        node_id = _text(node.get("id"))
        if not node_id or node_id in level_map:
            continue
        level_map[node_id] = _safe_int(node.get("level")) or 1
        path_map[node_id] = _text(node.get("name")) or node_id

    return level_map, parent_map, path_map, child_count_map


def _parse_graph_topology(
    project_key: str,
    topology_id: int,
    diagram: dict[str, Any],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]]:
    """Parse a graph-shaped diagram (``{"nodes": [...], "edges": [...]}``).

    Missing or ``null`` ``nodes``/``edges`` are treated as empty. An edge that
    repeats an earlier (source, target) pair is dropped, because the edge
    table is unique on that pair.

    Returns:
        ``(node_rows, edge_rows, entity_rows)``.
    """
    raw_nodes = [item for item in diagram.get("nodes") or [] if isinstance(item, dict)]
    raw_edges = [item for item in diagram.get("edges") or [] if isinstance(item, dict)]
    level_map, parent_map, path_map, child_count_map = _build_graph_node_context(
        raw_nodes, raw_edges
    )

    node_rows: list[dict[str, Any]] = []
    edge_rows: list[dict[str, Any]] = []
    entity_best: dict[tuple[str, int], tuple[int, list[str]]] = {}
    entity_node_ids: dict[tuple[str, int], list[str]] = {}

    for sort_index, node in enumerate(raw_nodes, start=1):
        node_id = _text(node.get("id"))
        if not node_id:
            continue
        level = level_map.get(node_id) or _safe_int(node.get("level")) or 1
        node_rows.append(
            {
                "project_key": project_key,
                "topology_id": topology_id,
                "node_id": node_id,
                "node_name": _text(node.get("name")) or node_id,
                "parent_node_id": parent_map.get(node_id),
                "level": level,
                "node_type_code": _safe_int(node.get("type")),
                "refer_id": _safe_int(node.get("refer_id")),
                "refer_level": _safe_int(node.get("refer_level")),
                "is_virtual": _bool_as_int(node.get("is_virtual")),
                "path_text": path_map.get(node_id, _text(node.get("name")) or node_id),
                "child_count": child_count_map.get(node_id, 0),
                "sort_index": sort_index,
            }
        )
        _record_deepest_entities(
            entity_best,
            entity_node_ids,
            "meter",
            _as_int_list(node.get("meter_list")),
            depth=level,
            node_id=node_id,
        )
        _record_deepest_entities(
            entity_best,
            entity_node_ids,
            "device",
            _as_int_list(node.get("device_list")),
            depth=level,
            node_id=node_id,
        )

    seen_edges: set[tuple[str, str]] = set()
    for sort_index, edge in enumerate(raw_edges, start=1):
        source_node_id = _text(edge.get("source"))
        target_node_id = _text(edge.get("target"))
        if not source_node_id or not target_node_id:
            continue
        if (source_node_id, target_node_id) in seen_edges:
            continue
        seen_edges.add((source_node_id, target_node_id))
        edge_rows.append(
            {
                "project_key": project_key,
                "topology_id": topology_id,
                "source_node_id": source_node_id,
                "target_node_id": target_node_id,
                "sort_index": sort_index,
            }
        )

    entity_rows = _build_entity_rows(
        project_key, topology_id, entity_best, entity_node_ids
    )
    return node_rows, edge_rows, entity_rows


# ---------------------------------------------------------------------------
# Refresh
# ---------------------------------------------------------------------------


def _select_topology_refs(
    topology_refs: list[dict[str, Any]], requested_ids: list[int]
) -> list[dict[str, Any]]:
    """Pick the refs to refresh: all of them, or those in *requested_ids* in that order.

    Raises:
        ValueError: if any requested id is absent from *topology_refs*.
    """
    if not requested_ids:
        return topology_refs
    available_topology_map = {ref["topology_id"]: ref for ref in topology_refs}
    missing = sorted(set(requested_ids) - set(available_topology_map))
    if missing:
        raise ValueError(f"topology_id not found in upstream topology list: {missing}")
    return [available_topology_map[topology_id] for topology_id in requested_ids]


def _fetch_topology_rows(
    project_key: str, topology_ref: dict[str, Any], refreshed_at: str
) -> tuple[
    dict[str, Any], list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]]
]:
    """Fetch one topology's detail upstream and turn it into cache rows.

    Returns:
        ``(registry_row, node_rows, edge_rows, entity_rows)``.

    Raises:
        ValueError: if the detail payload's ``data`` is not a dict.
    """
    topology_id = topology_ref["topology_id"]
    detail_payload = api_get_topology(project_key, topology_id)
    detail_data = detail_payload.get("data")
    if not isinstance(detail_data, dict):
        raise ValueError(
            f"topology get returned invalid data for topology_id={topology_id}"
        )

    diagram = detail_data.get("diagram")
    if isinstance(diagram, list):
        root_shape = "tree"
        parsed = _parse_tree_topology(project_key, topology_id, diagram)
    elif isinstance(diagram, dict):
        root_shape = "graph"
        parsed = _parse_graph_topology(project_key, topology_id, diagram)
    else:
        # No usable diagram: register the topology with no nodes.
        root_shape = "tree"
        parsed = ([], [], [])

    registry_row = {
        "project_key": project_key,
        "topology_id": topology_id,
        "topology_name": _text(detail_data.get("name")) or topology_ref["topology_name"],
        "topology_type": _safe_int(detail_data.get("type")) or 0,
        "object_type_code": _safe_int(detail_data.get("object")),
        "group_id": _safe_int(detail_data.get("group_id")) or topology_ref.get("group_id"),
        "root_shape": root_shape,
        "source_updated_time": _text(detail_data.get("updated_time")),
        "refreshed_at": refreshed_at,
        "is_active": 1,
    }
    node_rows, edge_rows, entity_rows = parsed
    return registry_row, node_rows, edge_rows, entity_rows


def _replace_cached_rows(
    project_key: str,
    requested_ids: list[int],
    *,
    group_rows: list[dict[str, Any]],
    registry_rows: list[dict[str, Any]],
    node_rows: list[dict[str, Any]],
    edge_rows: list[dict[str, Any]],
    entity_rows: list[dict[str, Any]],
) -> None:
    """Replace cached rows for *project_key* in one transaction.

    Groups are always rebuilt in full. Per-topology tables are cleared for the
    whole project, or only for *requested_ids* when that list is non-empty.
    """
    with Session(sql_engine()) as session:
        session.execute(
            delete(TopologyGroup).where(TopologyGroup.project_key == project_key)
        )
        for model in (TopologyEntityIndex, TopologyEdge, TopologyNode, TopologyRegistry):
            statement = delete(model).where(model.project_key == project_key)
            if requested_ids:
                statement = statement.where(model.topology_id.in_(requested_ids))
            session.execute(statement)

        session.add_all(TopologyGroup(**row) for row in group_rows)
        session.add_all(TopologyRegistry(**row) for row in registry_rows)
        session.add_all(TopologyNode(**row) for row in node_rows)
        session.add_all(TopologyEdge(**row) for row in edge_rows)
        session.add_all(TopologyEntityIndex(**row) for row in entity_rows)
        session.commit()


def refresh_topology_cache(
    project_key: str, topology_ids: list[int] | None = None
) -> dict[str, Any]:
    """Refresh the topology cache for *project_key* from the upstream config API.

    With no *topology_ids* (``None`` or empty), every topology in the upstream
    list is fetched and all cached rows for the project are replaced. With
    *topology_ids*, only those topologies are fetched and replaced; repeated
    ids are processed once. The group table is always rebuilt from the full
    upstream list.

    Returns:
        A summary with the row counts written, the refreshed ``topology_ids``
        and ``refreshed_at``. It also carries an ``mcp_note`` when no topology
        was refreshed.

    Raises:
        ValueError: if *project_key* is empty, an upstream payload is malformed,
            or a requested topology id is not in the upstream list.
    """
    project_key = _text(project_key)
    if not project_key:
        raise ValueError("project_key is required")
    ensure_topology_cache_tables()
    refreshed_at = _utc_now_iso()

    list_payload = api_list_topologies_with_group(project_key, group_ids=[])
    raw_items = list_payload.get("data")
    if not isinstance(raw_items, list):
        raise ValueError("topology list returned invalid data")
    group_rows, topology_refs = _collect_group_and_topology_refs(
        project_key, raw_items, refreshed_at=refreshed_at
    )

    # De-duplicate while keeping the caller's order. A repeated id would
    # otherwise yield duplicate rows and violate the unique constraints.
    requested_ids = list(dict.fromkeys(topology_ids or []))
    selected_topology_refs = _select_topology_refs(topology_refs, requested_ids)

    registry_rows: list[dict[str, Any]] = []
    node_rows: list[dict[str, Any]] = []
    edge_rows: list[dict[str, Any]] = []
    entity_rows: list[dict[str, Any]] = []
    # All upstream fetches happen before the DB transaction is opened.
    for topology_ref in selected_topology_refs:
        registry_row, topology_nodes, topology_edges, topology_entities = (
            _fetch_topology_rows(project_key, topology_ref, refreshed_at)
        )
        registry_rows.append(registry_row)
        node_rows.extend(topology_nodes)
        edge_rows.extend(topology_edges)
        entity_rows.extend(topology_entities)

    _replace_cached_rows(
        project_key,
        requested_ids,
        group_rows=group_rows,
        registry_rows=registry_rows,
        node_rows=node_rows,
        edge_rows=edge_rows,
        entity_rows=entity_rows,
    )

    result = {
        "project_key": project_key,
        "refreshed_group_count": len(group_rows),
        "refreshed_topology_count": len(registry_rows),
        "refreshed_node_count": len(node_rows),
        "refreshed_edge_count": len(edge_rows),
        "refreshed_entity_index_count": len(entity_rows),
        "topology_ids": [row["topology_id"] for row in registry_rows],
        "refreshed_at": refreshed_at,
    }
    if not registry_rows:
        result["mcp_note"] = (
            f"Project '{project_key}' currently has no topology data in upstream config."
        )
    return result


# ---------------------------------------------------------------------------
# Cache reads
# ---------------------------------------------------------------------------


def _no_cache_note(project_key: str) -> str:
    """Explain that *project_key* has no cached topologies and how to refresh it."""
    return (
        f"Project '{project_key}' has no cached topologies. Refresh it via "
        f"GET /topology/cache/refresh?project_key={project_key}. If the refresh "
        "already succeeded, the upstream project likely has no topology data."
    )


def _load_group_rows(session: Session, project_key: str) -> list[TopologyGroup]:
    """Load the project's groups ordered by level, sort index, then group id."""
    return list(
        session.scalars(
            select(TopologyGroup)
            .where(TopologyGroup.project_key == project_key)
            .order_by(
                TopologyGroup.level.asc(),
                TopologyGroup.sort_index.asc(),
                TopologyGroup.group_id.asc(),
            )
        )
    )


def _load_registry_rows(session: Session, project_key: str) -> list[TopologyRegistry]:
    """Load the project's topologies ordered by name, then topology id."""
    return list(
        session.scalars(
            select(TopologyRegistry)
            .where(TopologyRegistry.project_key == project_key)
            .order_by(
                TopologyRegistry.topology_name.asc(), TopologyRegistry.topology_id.asc()
            )
        )
    )


def _has_topology_cache(session: Session, project_key: str) -> bool:
    """Return whether at least one topology is cached for *project_key*."""
    first_id = session.scalar(
        select(TopologyRegistry.id)
        .where(TopologyRegistry.project_key == project_key)
        .limit(1)
    )
    return first_id is not None


def _require_topology_cache(session: Session, project_key: str) -> None:
    """Raise when *project_key* has no cached topologies.

    Raises:
        ValueError: carrying the same guidance the list endpoints return as
            ``mcp_note``.
    """
    if not _has_topology_cache(session, project_key):
        raise ValueError(_no_cache_note(project_key))


def list_topology_groups(project_key: str) -> dict[str, Any]:
    """Return the cached group hierarchy of *project_key* as a nested tree.

    Top-level groups (``parent_group_id`` is ``None``) are under ``"groups"``.
    Each group carries its sub-groups in ``"children"``. ``"total"`` counts
    every cached group.

    Raises:
        ValueError: if *project_key* is empty.
    """
    project_key = _text(project_key)
    if not project_key:
        raise ValueError("project_key is required")
    ensure_topology_cache_tables()
    with Session(sql_engine()) as session:
        if not _has_topology_cache(session, project_key):
            return {
                "project_key": project_key,
                "groups": [],
                "total": 0,
                "mcp_note": _no_cache_note(project_key),
            }
        group_rows = _load_group_rows(session, project_key)

    children_by_parent: dict[int | None, list[dict[str, Any]]] = defaultdict(list)
    node_by_group_id: dict[int, dict[str, Any]] = {}
    for group_row in group_rows:
        payload = {
            "group_id": group_row.group_id,
            "group_name": group_row.group_name,
            "parent_group_id": group_row.parent_group_id,
            "level": group_row.level,
            "group_path_text": group_row.group_path_text,
            "children": [],
        }
        node_by_group_id[group_row.group_id] = payload
        children_by_parent[group_row.parent_group_id].append(payload)

    for group_payload in node_by_group_id.values():
        group_payload["children"] = children_by_parent.get(group_payload["group_id"], [])

    return {
        "project_key": project_key,
        "groups": children_by_parent.get(None, []),
        "total": len(group_rows),
    }


def _group_path_map(group_rows: list[TopologyGroup]) -> dict[int, str]:
    """Map group id -> ``group_path_text``."""
    return {group_row.group_id: group_row.group_path_text for group_row in group_rows}


def _collect_descendant_group_ids(
    group_rows: list[TopologyGroup], group_id: int
) -> set[int]:
    """Return *group_id* plus the ids of all groups nested beneath it."""
    group_children: dict[int | None, list[int]] = defaultdict(list)
    for row in group_rows:
        group_children[row.parent_group_id].append(row.group_id)

    result: set[int] = set()
    queue: deque[int] = deque([group_id])
    while queue:
        current_group_id = queue.popleft()
        if current_group_id in result:
            continue
        result.add(current_group_id)
        queue.extend(group_children.get(current_group_id, []))
    return result


def list_topologies(
    project_key: str,
    *,
    group_id: int | None = None,
    object_type_code: int | None = None,
) -> dict[str, Any]:
    """List cached topologies, optionally filtered.

    Args:
        project_key: Project whose cache is read.
        group_id: Keep only topologies in this group or any nested sub-group.
        object_type_code: Keep only topologies with exactly this object type.

    Raises:
        ValueError: if *project_key* is empty.
    """
    project_key = _text(project_key)
    if not project_key:
        raise ValueError("project_key is required")
    ensure_topology_cache_tables()
    with Session(sql_engine()) as session:
        if not _has_topology_cache(session, project_key):
            return {
                "project_key": project_key,
                "topologies": [],
                "total": 0,
                "mcp_note": _no_cache_note(project_key),
            }
        group_rows = _load_group_rows(session, project_key)
        registry_rows = _load_registry_rows(session, project_key)

    group_path_map = _group_path_map(group_rows)
    allowed_group_ids: set[int] | None = None
    if group_id is not None:
        allowed_group_ids = _collect_descendant_group_ids(group_rows, group_id)

    topologies: list[dict[str, Any]] = []
    for registry_row in registry_rows:
        if allowed_group_ids is not None and registry_row.group_id not in allowed_group_ids:
            continue
        if object_type_code is not None and registry_row.object_type_code != object_type_code:
            continue
        topologies.append(
            {
                "topology_id": registry_row.topology_id,
                "topology_name": registry_row.topology_name,
                "topology_type": registry_row.topology_type,
                "object_type_code": registry_row.object_type_code,
                "group_id": registry_row.group_id,
                "group_path_text": group_path_map.get(registry_row.group_id or -1, ""),
                "root_shape": registry_row.root_shape,
                "refreshed_at": registry_row.refreshed_at,
            }
        )
    return {
        "project_key": project_key,
        "topologies": topologies,
        "total": len(topologies),
    }


def _get_registry_or_error(
    session: Session, project_key: str, topology_id: int
) -> TopologyRegistry:
    """Return the cached registry row, or raise ``ValueError`` if it is absent."""
    registry_row = session.scalar(
        select(TopologyRegistry).where(
            TopologyRegistry.project_key == project_key,
            TopologyRegistry.topology_id == topology_id,
        )
    )
    if registry_row is None:
        raise ValueError(f"topology_id not found in cache: {topology_id}")
    return registry_row


def _get_node_or_error(
    session: Session, project_key: str, topology_id: int, node_id: str
) -> TopologyNode:
    """Return the cached node row, or raise ``ValueError`` if it is absent."""
    node_row = session.scalar(
        select(TopologyNode).where(
            TopologyNode.project_key == project_key,
            TopologyNode.topology_id == topology_id,
            TopologyNode.node_id == node_id,
        )
    )
    if node_row is None:
        raise ValueError(f"node_id not found in cache: {node_id}")
    return node_row


def _node_payload(node_row: TopologyNode) -> dict[str, Any]:
    """Serialize a node row for API responses."""
    return {
        "node_id": node_row.node_id,
        "node_name": node_row.node_name,
        "level": node_row.level,
        "parent_node_id": node_row.parent_node_id,
        "refer_id": node_row.refer_id,
        "refer_level": node_row.refer_level,
        "is_virtual": bool(node_row.is_virtual),
        "path_text": node_row.path_text,
        "child_count": node_row.child_count,
    }


def _load_node_map(
    session: Session, project_key: str, topology_id: int
) -> dict[str, TopologyNode]:
    """Load all nodes of one topology keyed by node id."""
    rows = session.scalars(
        select(TopologyNode).where(
            TopologyNode.project_key == project_key,
            TopologyNode.topology_id == topology_id,
        )
    )
    return {row.node_id: row for row in rows}


def _load_adjacency(
    session: Session, project_key: str, topology_id: int
) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
    """Load one topology's edges as ``(parents_by_node, children_by_node)``.

    Neighbour lists follow the edges' ``sort_index`` then insertion id.
    """
    edges = session.scalars(
        select(TopologyEdge)
        .where(
            TopologyEdge.project_key == project_key,
            TopologyEdge.topology_id == topology_id,
        )
        .order_by(TopologyEdge.sort_index.asc(), TopologyEdge.id.asc())
    )
    parents_by_node: dict[str, list[str]] = defaultdict(list)
    children_by_node: dict[str, list[str]] = defaultdict(list)
    for edge in edges:
        parents_by_node[edge.target_node_id].append(edge.source_node_id)
        children_by_node[edge.source_node_id].append(edge.target_node_id)
    return parents_by_node, children_by_node


def _dedupe_preserve_order(node_ids: list[str]) -> list[str]:
    """Drop repeated ids, keeping the first occurrence of each."""
    return list(dict.fromkeys(node_ids))


def _node_list_payload(
    node_map: dict[str, TopologyNode], node_ids: list[str]
) -> list[dict[str, Any]]:
    """Serialize the known nodes among *node_ids*, de-duplicated, in order."""
    return [
        _node_payload(node_map[node_id])
        for node_id in _dedupe_preserve_order(node_ids)
        if node_id in node_map
    ]


def _sibling_ids(
    parent_ids: list[str], children_by_node: dict[str, list[str]], node_id: str
) -> list[str]:
    """Return the other children of *node_id*'s parent.

    The result is empty unless the node has exactly one parent.
    """
    if len(parent_ids) != 1:
        return []
    return [
        candidate
        for candidate in children_by_node.get(parent_ids[0], [])
        if candidate != node_id
    ]


def _collect_within_distance(
    node_map: dict[str, TopologyNode],
    adjacency: dict[str, list[str]],
    node_id: str,
    depth_limit: int,
    *,
    farthest_first: bool,
) -> list[dict[str, Any]]:
    """Breadth-first walk from *node_id* along *adjacency*, up to *depth_limit* hops.

    Each reached node is reported once, with its shortest hop count under
    ``"distance"``. Ordering is by distance (descending when
    *farthest_first*), then path text or name, then node id. Ids that occur
    only in edges (no cached node row) are walked through but not reported.
    """
    if depth_limit <= 0:
        return []
    visited: dict[str, int] = {}
    queue: deque[tuple[str, int]] = deque(
        (neighbor_id, 1) for neighbor_id in adjacency.get(node_id, [])
    )
    while queue:
        current_node_id, distance = queue.popleft()
        if distance > depth_limit:
            continue
        previous_distance = visited.get(current_node_id)
        if previous_distance is not None and previous_distance <= distance:
            continue
        visited[current_node_id] = distance
        for neighbor_id in adjacency.get(current_node_id, []):
            queue.append((neighbor_id, distance + 1))

    direction = -1 if farthest_first else 1
    reportable_ids = sorted(
        (candidate for candidate in visited if candidate in node_map),
        key=lambda item: (
            direction * visited[item],
            node_map[item].path_text or node_map[item].node_name,
            item,
        ),
    )
    result: list[dict[str, Any]] = []
    for current_node_id in reportable_ids:
        node_payload = _node_payload(node_map[current_node_id])
        node_payload["distance"] = visited[current_node_id]
        result.append(node_payload)
    return result


def _collect_ancestors(
    node_map: dict[str, TopologyNode],
    parents_by_node: dict[str, list[str]],
    node_id: str,
    depth_limit: int,
) -> list[dict[str, Any]]:
    """Upstream nodes within *depth_limit* hops of *node_id*, farthest first."""
    return _collect_within_distance(
        node_map, parents_by_node, node_id, depth_limit, farthest_first=True
    )


def _collect_descendants(
    node_map: dict[str, TopologyNode],
    children_by_node: dict[str, list[str]],
    node_id: str,
    depth_limit: int,
) -> list[dict[str, Any]]:
    """Downstream nodes within *depth_limit* hops of *node_id*, nearest first."""
    return _collect_within_distance(
        node_map, children_by_node, node_id, depth_limit, farthest_first=False
    )


def _topology_metadata_payload(
    registry_row: TopologyRegistry, group_path_text: str
) -> dict[str, Any]:
    """Serialize a registry row (plus its group path) for API responses."""
    return {
        "topology_id": registry_row.topology_id,
        "topology_name": registry_row.topology_name,
        "topology_type": registry_row.topology_type,
        "object_type_code": registry_row.object_type_code,
        "group_id": registry_row.group_id,
        "group_path_text": group_path_text,
        "root_shape": registry_row.root_shape,
    }


def get_topology_node(
    project_key: str,
    topology_id: int,
    node_id: str,
    *,
    include_siblings: bool = True,
    include_children: bool = True,
) -> dict[str, Any]:
    """Return one cached node with its topology, parents, children and siblings.

    Siblings are only computed when the node has exactly one parent.

    Raises:
        ValueError: if *project_key* or *node_id* is empty, the project has no
            cache, or the topology or node is not cached.
    """
    project_key = _text(project_key)
    node_id = _text(node_id)
    if not project_key:
        raise ValueError("project_key is required")
    if not node_id:
        raise ValueError("node_id is required")
    ensure_topology_cache_tables()
    with Session(sql_engine()) as session:
        _require_topology_cache(session, project_key)
        group_path_map = _group_path_map(_load_group_rows(session, project_key))
        registry_row = _get_registry_or_error(session, project_key, topology_id)
        node_row = _get_node_or_error(session, project_key, topology_id, node_id)
        node_map = _load_node_map(session, project_key, topology_id)
        parents_by_node, children_by_node = _load_adjacency(
            session, project_key, topology_id
        )

        parent_ids = parents_by_node.get(node_id, [])
        child_ids = children_by_node.get(node_id, []) if include_children else []
        sibling_ids = (
            _sibling_ids(parent_ids, children_by_node, node_id)
            if include_siblings
            else []
        )
        return {
            "topology": _topology_metadata_payload(
                registry_row, group_path_map.get(registry_row.group_id or -1, "")
            ),
            "node": _node_payload(node_row),
            "parents": _node_list_payload(node_map, parent_ids),
            "children": _node_list_payload(node_map, child_ids),
            "siblings": _node_list_payload(node_map, sibling_ids),
        }


def find_topology_context(
    project_key: str,
    entity_type: str,
    entity_id: int,
    *,
    topology_id: int | None = None,
    include_siblings: bool = True,
    ancestor_depth: int = 5,
    descendant_depth: int = 2,
) -> dict[str, Any]:
    """Locate a meter/device in the cached topologies and describe its surroundings.

    Every index entry for the entity (optionally limited to *topology_id*)
    becomes one match. Matches are ordered by depth (deepest first), then
    topology id. Each match lists the node's parents, children, ancestors
    (up to *ancestor_depth* hops), descendants (up to *descendant_depth*
    hops) and, if requested, its siblings. Index entries whose node row is
    missing are skipped.

    Raises:
        ValueError: if *entity_type* is invalid, *project_key* is empty,
            *entity_id* is not positive, or the project has no cache.
    """
    project_key = _text(project_key)
    normalized_entity_type = _normalize_entity_type(entity_type)
    if not project_key:
        raise ValueError("project_key is required")
    if entity_id <= 0:
        raise ValueError("entity_id must be a positive integer")
    ensure_topology_cache_tables()
    with Session(sql_engine()) as session:
        _require_topology_cache(session, project_key)
        group_path_map = _group_path_map(_load_group_rows(session, project_key))

        query = select(TopologyEntityIndex).where(
            TopologyEntityIndex.project_key == project_key,
            TopologyEntityIndex.entity_type == normalized_entity_type,
            TopologyEntityIndex.entity_id == entity_id,
        )
        if topology_id is not None:
            query = query.where(TopologyEntityIndex.topology_id == topology_id)
        index_rows = list(
            session.scalars(
                query.order_by(
                    TopologyEntityIndex.depth.desc(),
                    TopologyEntityIndex.topology_id.asc(),
                )
            )
        )

        # Per-topology data is loaded once and reused by every match in it.
        registry_by_topology: dict[int, TopologyRegistry] = {}
        node_map_by_topology: dict[int, dict[str, TopologyNode]] = {}
        adjacency_by_topology: dict[
            int, tuple[dict[str, list[str]], dict[str, list[str]]]
        ] = {}
        matches: list[dict[str, Any]] = []
        for index_row in index_rows:
            current_topology_id = index_row.topology_id
            if current_topology_id not in registry_by_topology:
                registry_by_topology[current_topology_id] = _get_registry_or_error(
                    session, project_key, current_topology_id
                )
                node_map_by_topology[current_topology_id] = _load_node_map(
                    session, project_key, current_topology_id
                )
                adjacency_by_topology[current_topology_id] = _load_adjacency(
                    session, project_key, current_topology_id
                )

            registry_row = registry_by_topology[current_topology_id]
            node_map = node_map_by_topology[current_topology_id]
            node_row = node_map.get(index_row.node_id)
            if node_row is None:
                continue
            parents_by_node, children_by_node = adjacency_by_topology[current_topology_id]
            parent_ids = parents_by_node.get(index_row.node_id, [])
            child_ids = children_by_node.get(index_row.node_id, [])
            sibling_ids = (
                _sibling_ids(parent_ids, children_by_node, index_row.node_id)
                if include_siblings
                else []
            )
            matches.append(
                {
                    "topology": _topology_metadata_payload(
                        registry_row,
                        group_path_map.get(registry_row.group_id or -1, ""),
                    ),
                    "self": _node_payload(node_row),
                    "parents": _node_list_payload(node_map, parent_ids),
                    "children": _node_list_payload(node_map, child_ids),
                    "ancestors": _collect_ancestors(
                        node_map,
                        parents_by_node,
                        index_row.node_id,
                        max(ancestor_depth, 0),
                    ),
                    "descendants": _collect_descendants(
                        node_map,
                        children_by_node,
                        index_row.node_id,
                        max(descendant_depth, 0),
                    ),
                    "siblings": _node_list_payload(node_map, sibling_ids),
                }
            )

    return {
        "query": {
            "project_key": project_key,
            "entity_type": normalized_entity_type,
            "entity_id": entity_id,
            "topology_id": topology_id,
        },
        "matches": matches,
        "total_matches": len(matches),
    }