import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlmodel import Session, select

from models.platform import ManagedNodeRecord


@dataclass(frozen=True)
class ManagedNode:
    """Immutable snapshot of a managed node as exposed by ``NodeRegistryService``.

    ``frozen=True`` only prevents attribute reassignment: the dict fields are
    still mutable objects and should be treated as read-only by callers.
    Because those fields are dicts, instances are not hashable (hashing a
    frozen dataclass hashes its fields), so they cannot be set members or
    dict keys.
    """

    node_id: str
    display_name: str
    base_url: str = ""
    enabled: bool = True
    auth_token: str = ""
    # Free-form metadata. The "transport_kind", "runtime_kind" and
    # "core_adapter" keys are mirrored into dedicated DB columns on upsert.
    metadata: Dict[str, Any] = field(default_factory=dict)
    capabilities: Dict[str, Any] = field(default_factory=dict)
    resources: Dict[str, Any] = field(default_factory=dict)
    # ISO-8601 timestamp string; values read back from the database are
    # rendered as UTC with a trailing "Z".
    last_seen_at: Optional[str] = None


class NodeRegistryService:
    """In-memory cache of managed nodes backed by ``ManagedNodeRecord`` rows.

    Node ids are case-insensitive: every key is stripped and lower-cased.
    The DB-backed mutators (``upsert_node``, ``mark_node_seen``,
    ``delete_node``) write the row, commit, and then update the cache from
    the committed state. ``register_node`` touches the cache only.
    """

    def __init__(self) -> None:
        # Cache keyed by normalized (stripped, lower-cased) node id.
        self._nodes: Dict[str, ManagedNode] = {}

    # ------------------------------------------------------------------ #
    # Cache access
    # ------------------------------------------------------------------ #

    def register_node(self, node: ManagedNode) -> None:
        """Normalize *node* and store it in the in-memory cache (no DB write)."""
        normalized = self._normalize_node(node)
        self._nodes[normalized.node_id] = normalized

    def list_nodes(self) -> List[ManagedNode]:
        """Return every cached node, ordered by node id."""
        return [self._nodes[key] for key in sorted(self._nodes)]

    def get_node(self, node_id: str) -> Optional[ManagedNode]:
        """Return the cached node for *node_id*, or ``None`` if blank or unknown."""
        key = self._normalize_key(node_id)
        if not key:
            return None
        return self._nodes.get(key)

    def require_node(self, node_id: str) -> ManagedNode:
        """Return the cached node for *node_id*; it must exist and be enabled.

        Raises:
            ValueError: if the node is unknown or disabled.
        """
        node = self.get_node(node_id)
        if node is None:
            raise ValueError(f"Managed node not found: {node_id}")
        if not node.enabled:
            raise ValueError(f"Managed node is disabled: {node_id}")
        return node

    # ------------------------------------------------------------------ #
    # Database-backed operations
    # ------------------------------------------------------------------ #

    def load_from_session(self, session: Session) -> List[ManagedNode]:
        """Replace the cache with every ``ManagedNodeRecord`` row and return it."""
        rows = session.exec(select(ManagedNodeRecord)).all()
        # Build the new cache completely before swapping it in, so a failure
        # while converting a row leaves the previous cache intact.
        loaded: Dict[str, ManagedNode] = {}
        for row in rows:
            node = self._normalize_node(self._row_to_node(row))
            loaded[node.node_id] = node
        self._nodes = loaded
        return self.list_nodes()

    def upsert_node(self, session: Session, node: ManagedNode) -> ManagedNode:
        """Insert or update the row for *node*, commit, and refresh the cache.

        ``last_seen_at`` is only overwritten when *node* carries a parseable
        timestamp, and ``created_at`` is set only when the row has none yet.

        Returns:
            The node as re-read from the committed row. Disabled nodes are
            returned too, because disabling a node via upsert is a valid update.

        Raises:
            ValueError: if the node id is blank.
        """
        normalized = self._normalize_node(node)
        if not normalized.node_id:
            # Refuse to persist a row with an empty primary key.
            raise ValueError("node_id is required")

        row = session.get(ManagedNodeRecord, normalized.node_id)
        if row is None:
            row = ManagedNodeRecord(node_id=normalized.node_id)

        metadata = dict(normalized.metadata or {})
        now = self._utcnow()
        row.display_name = normalized.display_name or normalized.node_id
        row.base_url = normalized.base_url or ""
        row.enabled = bool(normalized.enabled)
        row.auth_token = normalized.auth_token or ""
        row.transport_kind = self._clean_kind(metadata.get("transport_kind"), "edge")
        row.runtime_kind = self._clean_kind(metadata.get("runtime_kind"), "docker")
        row.core_adapter = self._clean_kind(metadata.get("core_adapter"), "nanobot")
        row.metadata_json = self._dump_json(metadata)
        row.capabilities_json = self._dump_json(normalized.capabilities)
        row.resources_json = self._dump_json(normalized.resources)
        # Keep the previous value when the incoming timestamp is absent/invalid.
        row.last_seen_at = self._parse_datetime(normalized.last_seen_at) or row.last_seen_at
        row.updated_at = now
        if row.created_at is None:
            row.created_at = now

        session.add(row)
        session.commit()
        session.refresh(row)
        # Return directly instead of via require_node(): that would raise for
        # a node this very call just (successfully) disabled.
        return self._cache_row(row)

    def mark_node_seen(
        self,
        session: Session,
        *,
        node_id: str,
        display_name: Optional[str] = None,
        capabilities: Optional[Dict[str, Any]] = None,
        resources: Optional[Dict[str, Any]] = None,
    ) -> ManagedNode:
        """Record a heartbeat for an existing node and refresh the cache.

        Sets ``last_seen_at``/``updated_at`` to now. A non-blank *display_name*
        replaces the stored one; *capabilities*/*resources* replace the stored
        JSON only when not ``None``.

        Raises:
            ValueError: if the node does not exist, or (after the update has
                been committed) if the node is disabled.
        """
        key = self._normalize_key(node_id)
        row = session.get(ManagedNodeRecord, key)
        if row is None:
            raise ValueError(f"Managed node not found: {node_id}")

        name = str(display_name or "").strip()
        if name:
            row.display_name = name
        if capabilities is not None:
            row.capabilities_json = self._dump_json(capabilities)
        if resources is not None:
            row.resources_json = self._dump_json(resources)
        now = self._utcnow()
        row.last_seen_at = now
        row.updated_at = now

        session.add(row)
        session.commit()
        session.refresh(row)
        self._cache_row(row)
        # NOTE(review): require_node() raises for disabled nodes, so a heartbeat
        # from a disabled node is persisted *and* reported as an error. Kept
        # for caller compatibility; confirm this is intended.
        return self.require_node(key)

    def delete_node(self, session: Session, node_id: str) -> None:
        """Delete the row for *node_id*, commit, and drop it from the cache.

        Raises:
            ValueError: if *node_id* is blank or no such node exists.
        """
        key = self._normalize_key(node_id)
        if not key:
            raise ValueError("node_id is required")
        row = session.get(ManagedNodeRecord, key)
        if row is None:
            raise ValueError(f"Managed node not found: {node_id}")
        session.delete(row)
        session.commit()
        self._nodes.pop(key, None)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _cache_row(self, row: ManagedNodeRecord) -> ManagedNode:
        """Convert a committed row to a normalized node, cache it, and return it."""
        node = self._normalize_node(self._row_to_node(row))
        self._nodes[node.node_id] = node
        return node

    @staticmethod
    def _normalize_key(node_id: Optional[str]) -> str:
        """Return the canonical cache/primary key for *node_id* ("" when blank)."""
        return str(node_id or "").strip().lower()

    @staticmethod
    def _clean_kind(value: Any, default: str) -> str:
        """Return *value* stripped and lower-cased, or *default* when blank."""
        return str(value or default).strip().lower() or default

    @staticmethod
    def _dump_json(value: Optional[Dict[str, Any]]) -> str:
        """Serialize a mapping deterministically (sorted keys, non-ASCII kept)."""
        return json.dumps(dict(value or {}), ensure_ascii=False, sort_keys=True)

    @staticmethod
    def _load_json_dict(raw: Optional[str]) -> Dict[str, Any]:
        """Parse *raw* as a JSON object; return ``{}`` if blank, malformed, or not a dict."""
        try:
            loaded = json.loads(str(raw or "{}"))
        except (TypeError, ValueError, RecursionError):
            # Best effort: a corrupt column must not make the node unloadable.
            return {}
        return loaded if isinstance(loaded, dict) else {}

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Equivalent to the deprecated ``datetime.utcnow()`` (deprecated in
        Python 3.12), preserving the naive-UTC convention used for stored
        timestamps.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _normalize_node(node: ManagedNode) -> ManagedNode:
        """Return a copy of *node* with cleaned fields and copied dicts.

        The node id is stripped and lower-cased. The display name falls back
        to the node id when blank, and string fields are stripped. A blank
        ``last_seen_at`` becomes ``None``.
        """
        node_id = str(node.node_id or "").strip().lower()
        return ManagedNode(
            node_id=node_id,
            display_name=str(node.display_name or node.node_id or "").strip() or node_id,
            base_url=str(node.base_url or "").strip(),
            enabled=bool(node.enabled),
            auth_token=str(node.auth_token or "").strip(),
            metadata=dict(node.metadata or {}),
            capabilities=dict(node.capabilities or {}),
            resources=dict(node.resources or {}),
            last_seen_at=str(node.last_seen_at or "").strip() or None,
        )

    @classmethod
    def _row_to_node(cls, row: ManagedNodeRecord) -> ManagedNode:
        """Build a ``ManagedNode`` from a DB row.

        Malformed JSON columns decode to ``{}``. The kind columns are copied
        into ``metadata`` unless the stored metadata already defines them.
        """
        metadata = cls._load_json_dict(row.metadata_json)
        capabilities = cls._load_json_dict(row.capabilities_json)
        resources = cls._load_json_dict(row.resources_json)
        metadata.setdefault("transport_kind", cls._clean_kind(row.transport_kind, "edge"))
        metadata.setdefault("runtime_kind", cls._clean_kind(row.runtime_kind, "docker"))
        metadata.setdefault("core_adapter", cls._clean_kind(row.core_adapter, "nanobot"))

        node_id = str(row.node_id or "").strip().lower()
        return ManagedNode(
            node_id=node_id,
            # Same fallback as _normalize_node: never expose a blank name.
            display_name=str(row.display_name or row.node_id or "").strip() or node_id,
            base_url=str(row.base_url or "").strip(),
            enabled=bool(row.enabled),
            auth_token=str(row.auth_token or "").strip(),
            metadata=metadata,
            capabilities=capabilities,
            resources=resources,
            last_seen_at=cls._format_datetime(row.last_seen_at),
        )

    @staticmethod
    def _format_datetime(value: Optional[datetime]) -> Optional[str]:
        """Render a stored timestamp as ISO-8601 UTC with a trailing ``Z``."""
        if value is None:
            return None
        if value.tzinfo is not None:
            # An aware value would otherwise render as "...+00:00Z". Convert it
            # to naive UTC so "Z" is the only zone designator.
            value = value.astimezone(timezone.utc).replace(tzinfo=None)
        return value.isoformat() + "Z"

    @staticmethod
    def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 string into a naive UTC datetime.

        A trailing ``Z`` is accepted, since ``fromisoformat`` only understands
        it from Python 3.11. Offset-aware inputs are converted to UTC and made
        naive, so they match the naive-UTC values produced by ``_utcnow``.

        Returns:
            The parsed datetime, or ``None`` when *value* is blank or invalid.
        """
        raw = str(value or "").strip()
        if not raw:
            return None
        if raw.endswith(("Z", "z")):
            raw = raw[:-1]
        try:
            parsed = datetime.fromisoformat(raw)
        except ValueError:
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed