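# dashboard-nanobot/backend/services/bot_config_state_service.py
#
# Source-viewer metadata captured with this file (not code):
# 321 lines, 13 KiB, Python (Raw | Permalink | Normal View | History)
# Timestamp shown by the viewer: 2026-03-26 16:12:46 +00:00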
import json
import os
import re
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List

from fastapi import HTTPException
from sqlmodel import Session

from models.bot import BotInstance
# Shared building blocks for the callback signatures below.
_BotId = str
_JsonObject = Dict[str, Any]

# Edge state store access. Readers return a dict payload; writers return a bool,
# and BotConfigStateService falls back to local JSON files when that result is falsy.
ReadEdgeStateData = Callable[..., _JsonObject]
WriteEdgeStateData = Callable[..., bool]

# Bot config access keyed by bot id.
ReadBotConfig = Callable[[_BotId], _JsonObject]
WriteBotConfig = Callable[[_BotId, _JsonObject], None]

# Called with a bot id after env/MCP updates (presumably drops cached bot details).
InvalidateBotCache = Callable[[_BotId], None]

# Maps a bot id to a local JSON file path (env / cron fallback stores).
PathResolver = Callable[[_BotId], str]

# Coerces arbitrary input into a str -> str env mapping.
NormalizeEnvParams = Callable[[Any], Dict[str, str]]
class BotConfigStateService:
    """Read and update per-bot configuration state.

    Three kinds of state are handled:

    * env params - stored in the edge state store under ``"env"``, with a local
      JSON file (path from ``env_store_path``) as fallback;
    * MCP servers - stored under ``tools.mcpServers`` in the bot config accessed
      through ``read_bot_config`` / ``write_bot_config``;
    * cron jobs - stored in the edge state store under ``"cron"``, with a local
      JSON file (path from ``cron_store_path``) as fallback.

    Reads use the local file only when the edge store returns no data; writes use
    it only when the edge write callback returns a falsy result. The ``*_for_bot``
    variants first check that the bot exists and raise ``HTTPException(404)``
    otherwise.
    """

    # Allowed MCP server names: 1-64 chars of ASCII letters, digits, ".", "_", "-".
    _MCP_SERVER_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
    # Sort key for cron jobs without a usable next-run time, so they sort last.
    _NO_NEXT_RUN_SORT_KEY = 2**62

    def __init__(
        self,
        *,
        read_edge_state_data: ReadEdgeStateData,
        write_edge_state_data: WriteEdgeStateData,
        read_bot_config: ReadBotConfig,
        write_bot_config: WriteBotConfig,
        invalidate_bot_detail_cache: InvalidateBotCache,
        env_store_path: PathResolver,
        cron_store_path: PathResolver,
        normalize_env_params: NormalizeEnvParams,
    ) -> None:
        """Store the injected storage and normalization callbacks.

        Args:
            read_edge_state_data: Reads a state payload; called with ``bot_id``,
                ``state_key`` and ``default_payload`` keywords.
            write_edge_state_data: Writes a state payload; called with ``bot_id``,
                ``state_key`` and ``data``. A falsy return makes the service write
                the local JSON file instead.
            read_bot_config: Returns the bot config dict for a bot id.
            write_bot_config: Persists a bot config dict for a bot id.
            invalidate_bot_detail_cache: Called with the bot id after env-param
                and MCP updates.
            env_store_path: Maps a bot id to the local env JSON file path.
            cron_store_path: Maps a bot id to the local cron JSON file path.
            normalize_env_params: Coerces arbitrary input into a ``str -> str`` dict.
        """
        self._read_edge_state_data = read_edge_state_data
        self._write_edge_state_data = write_edge_state_data
        self._read_bot_config = read_bot_config
        self._write_bot_config = write_bot_config
        self._invalidate_bot_detail_cache = invalidate_bot_detail_cache
        self._env_store_path = env_store_path
        self._cron_store_path = cron_store_path
        self._normalize_env_params = normalize_env_params

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot row for *bot_id* or raise ``HTTPException(404)``."""
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _write_json_atomic(path: str, payload: Any) -> None:
        """Write *payload* as pretty-printed UTF-8 JSON to *path*.

        The data is written to ``<path>.tmp`` and then moved over *path* with
        ``os.replace``, so the target is only replaced after the full payload has
        been written. If writing or renaming fails, the temp file is removed and
        the original exception propagates.
        """
        directory = os.path.dirname(path)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(payload, file, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @staticmethod
    def _coerce_version(raw: Any) -> int:
        """Return *raw* as a store version >= 1; unparsable or falsy values become 1."""
        try:
            return max(1, int(raw or 1))
        except (TypeError, ValueError, OverflowError):
            return 1

    @staticmethod
    def _empty_cron_store() -> Dict[str, Any]:
        """Return a fresh, empty cron store."""
        return {"version": 1, "jobs": []}

    @classmethod
    def _cron_sort_key(cls, job: Dict[str, Any]) -> int:
        """Sort key ordering jobs by ``state.nextRunAtMs``.

        Jobs whose ``state`` is not a dict, or whose ``nextRunAtMs`` is missing,
        zero or not an integer, sort last instead of aborting the whole listing.
        """
        state = job.get("state")
        next_run = state.get("nextRunAtMs") if isinstance(state, dict) else None
        try:
            return int(next_run or cls._NO_NEXT_RUN_SORT_KEY)
        except (TypeError, ValueError, OverflowError):
            return cls._NO_NEXT_RUN_SORT_KEY

    @staticmethod
    def _job_matches(row: Any, job_id: str) -> bool:
        """Return True if *row* is a job dict whose ``id`` (as a string) equals *job_id*."""
        if not isinstance(row, dict):
            return False
        row_id = row.get("id")
        # Rows without an id must not match job_id "None" via str(None).
        return row_id is not None and str(row_id) == job_id

    # ------------------------------------------------------------------
    # Env params
    # ------------------------------------------------------------------

    def read_env_store(self, bot_id: str) -> Dict[str, str]:
        """Return the bot's env params, preferring the edge store over the local file.

        Returns an empty dict when neither source has data or the file is unreadable.
        """
        data = self._read_edge_state_data(bot_id=bot_id, state_key="env", default_payload={})
        if data:
            return self._normalize_env_params(data)
        path = self._env_store_path(bot_id)
        if not os.path.isfile(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
            return self._normalize_env_params(payload)
        except Exception:
            # Best effort: a corrupt or unreadable local file counts as "no env params".
            return {}

    def write_env_store(self, bot_id: str, env_params: Dict[str, str]) -> None:
        """Normalize and persist *env_params*, using the local file if the edge write fails."""
        normalized_env = self._normalize_env_params(env_params)
        if self._write_edge_state_data(bot_id=bot_id, state_key="env", data=normalized_env):
            return
        self._write_json_atomic(self._env_store_path(bot_id), normalized_env)

    def get_env_params(self, bot_id: str) -> Dict[str, Any]:
        """Return ``{"bot_id", "env_params"}`` for *bot_id*."""
        return {
            "bot_id": bot_id,
            "env_params": self.read_env_store(bot_id),
        }

    def get_env_params_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like :meth:`get_env_params`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_env_params(bot_id)

    def update_env_params(self, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Replace the bot's env params and invalidate its cached details.

        Returns the normalized params and ``restart_required: True``.
        """
        normalized = self._normalize_env_params(env_params)
        self.write_env_store(bot_id, normalized)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "env_params": normalized,
            "restart_required": True,
        }

    def update_env_params_for_bot(self, *, session: Session, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Like :meth:`update_env_params`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_env_params(bot_id, env_params)

    # ------------------------------------------------------------------
    # MCP servers
    # ------------------------------------------------------------------

    def normalize_mcp_servers(self, raw: Any) -> Dict[str, Dict[str, Any]]:
        """Return a validated ``{name: {type, url, headers, toolTimeout}}`` mapping.

        Entries are dropped when the name fails ``_MCP_SERVER_NAME_RE``, the config
        is not a dict, or ``url`` is empty. ``type`` falls back to
        ``"streamableHttp"`` unless it is ``"sse"``; headers with empty keys are
        dropped and values are stringified and stripped; ``toolTimeout`` is coerced
        to int (default 60) and clamped to ``[1, 600]``. Non-dict input yields ``{}``.
        """
        if not isinstance(raw, dict):
            return {}
        rows: Dict[str, Dict[str, Any]] = {}
        for server_name, server_cfg in raw.items():
            name = str(server_name or "").strip()
            if not name or not self._MCP_SERVER_NAME_RE.fullmatch(name):
                continue
            if not isinstance(server_cfg, dict):
                continue
            url = str(server_cfg.get("url") or "").strip()
            if not url:
                continue
            transport_type = str(server_cfg.get("type") or "streamableHttp").strip()
            if transport_type not in {"streamableHttp", "sse"}:
                transport_type = "streamableHttp"
            headers_raw = server_cfg.get("headers")
            headers: Dict[str, str] = {}
            if isinstance(headers_raw, dict):
                for key, value in headers_raw.items():
                    header_key = str(key or "").strip()
                    if not header_key:
                        continue
                    headers[header_key] = str(value or "").strip()
            timeout_raw = server_cfg.get("toolTimeout", 60)
            try:
                timeout = int(timeout_raw)
            except (TypeError, ValueError, OverflowError):
                timeout = 60
            timeout = max(1, min(timeout, 600))
            rows[name] = {
                "type": transport_type,
                "url": url,
                "headers": headers,
                "toolTimeout": timeout,
            }
        return rows

    def _merge_mcp_servers_preserving_extras(
        self,
        current_raw: Any,
        normalized: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Overlay *normalized* server configs onto the currently stored ones.

        For each normalized server, keys already stored for that name but not
        produced by normalization are kept. Servers absent from *normalized* are
        dropped from the result.
        """
        current_map = current_raw if isinstance(current_raw, dict) else {}
        merged: Dict[str, Dict[str, Any]] = {}
        for name, normalized_cfg in normalized.items():
            base = current_map.get(name)
            next_cfg = dict(base) if isinstance(base, dict) else {}
            next_cfg.update(normalized_cfg)
            merged[name] = next_cfg
        return merged

    def _sanitize_mcp_servers_in_config_data(self, config_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Re-normalize ``tools.mcpServers`` in *config_data* in place and return it.

        Creates ``tools`` if it is missing or not a dict. Returns ``{}`` without
        changes when *config_data* itself is not a dict.
        """
        if not isinstance(config_data, dict):
            return {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        current_raw = tools_cfg.get("mcpServers")
        normalized = self.normalize_mcp_servers(current_raw)
        merged = self._merge_mcp_servers_preserving_extras(current_raw, normalized)
        tools_cfg["mcpServers"] = merged
        config_data["tools"] = tools_cfg
        return merged

    def get_mcp_config(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's normalized MCP servers read from its bot config."""
        config_data = self._read_bot_config(bot_id)
        tools_cfg = config_data.get("tools") if isinstance(config_data, dict) else {}
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        mcp_servers = self.normalize_mcp_servers(tools_cfg.get("mcpServers"))
        return {
            "bot_id": bot_id,
            "mcp_servers": mcp_servers,
            "locked_servers": [],
            "restart_required": True,
        }

    def get_mcp_config_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like :meth:`get_mcp_config`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_mcp_config(bot_id)

    def update_mcp_config(self, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Replace the bot's MCP servers with the normalized *mcp_servers*.

        Servers not present in *mcp_servers* are removed; extra keys already stored
        on servers that remain are preserved. The bot config is written back and
        the bot's cached details are invalidated.
        """
        config_data = self._read_bot_config(bot_id)
        if not isinstance(config_data, dict):
            config_data = {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        normalized_mcp_servers = self.normalize_mcp_servers(mcp_servers or {})
        current_mcp_servers = tools_cfg.get("mcpServers")
        tools_cfg["mcpServers"] = self._merge_mcp_servers_preserving_extras(
            current_mcp_servers, normalized_mcp_servers
        )
        config_data["tools"] = tools_cfg
        # Final pass over exactly what will be written, so the response mirrors the saved state.
        sanitized_after_save = self._sanitize_mcp_servers_in_config_data(config_data)
        self._write_bot_config(bot_id, config_data)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "mcp_servers": self.normalize_mcp_servers(sanitized_after_save),
            "locked_servers": [],
            "restart_required": True,
        }

    def update_mcp_config_for_bot(self, *, session: Session, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Like :meth:`update_mcp_config`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_mcp_config(bot_id, mcp_servers)

    # ------------------------------------------------------------------
    # Cron jobs
    # ------------------------------------------------------------------

    def read_cron_store(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's cron store as ``{"version": int >= 1, "jobs": list, ...}``.

        The edge store is preferred; otherwise the local JSON file is read (its
        extra top-level keys are kept). Missing or unreadable data yields an
        empty store.
        """
        data = self._read_edge_state_data(
            bot_id=bot_id,
            state_key="cron",
            default_payload=self._empty_cron_store(),
        )
        if isinstance(data, dict) and data:
            jobs = data.get("jobs")
            if not isinstance(jobs, list):
                jobs = []
            return {"version": self._coerce_version(data.get("version", 1)), "jobs": jobs}
        path = self._cron_store_path(bot_id)
        if not os.path.isfile(path):
            return self._empty_cron_store()
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except Exception:
            # Best effort: a corrupt or unreadable local file counts as an empty store.
            return self._empty_cron_store()
        if not isinstance(payload, dict):
            return self._empty_cron_store()
        if not isinstance(payload.get("jobs"), list):
            payload["jobs"] = []
        # Normalize like the edge-store branch so callers can rely on an int version.
        payload["version"] = self._coerce_version(payload.get("version", 1))
        return payload

    def write_cron_store(self, bot_id: str, store: Dict[str, Any]) -> None:
        """Persist *store* after coercing ``jobs`` to a list and ``version`` to an int >= 1.

        Uses the local JSON file only if the edge write returns a falsy result.
        """
        normalized_store = dict(store if isinstance(store, dict) else {})
        if not isinstance(normalized_store.get("jobs"), list):
            normalized_store["jobs"] = []
        normalized_store["version"] = self._coerce_version(normalized_store.get("version", 1))
        if self._write_edge_state_data(bot_id=bot_id, state_key="cron", data=normalized_store):
            return
        self._write_json_atomic(self._cron_store_path(bot_id), normalized_store)

    def list_cron_jobs(self, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Return the bot's cron jobs ordered by ``state.nextRunAtMs`` (unscheduled last).

        Non-dict entries are skipped; jobs with a falsy ``enabled`` are skipped
        unless *include_disabled* is True.
        """
        store = self.read_cron_store(bot_id)
        rows = [
            row
            for row in store.get("jobs", [])
            if isinstance(row, dict) and (include_disabled or bool(row.get("enabled", True)))
        ]
        rows.sort(key=self._cron_sort_key)
        return {"bot_id": bot_id, "version": self._coerce_version(store.get("version", 1)), "jobs": rows}

    def list_cron_jobs_for_bot(self, *, session: Session, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Like :meth:`list_cron_jobs`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.list_cron_jobs(bot_id, include_disabled=include_disabled)

    def stop_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Disable the job with id *job_id* and stamp ``updatedAtMs``.

        Raises:
            HTTPException: 404 if no job with that id exists.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        found = next((row for row in jobs if self._job_matches(row, job_id)), None)
        if found is None:
            raise HTTPException(status_code=404, detail="Cron job not found")
        found["enabled"] = False
        # Epoch milliseconds. datetime.utcnow().timestamp() would interpret the naive
        # UTC value as local time and be off by the host's UTC offset.
        found["updatedAtMs"] = int(datetime.now(timezone.utc).timestamp() * 1000)
        self.write_cron_store(bot_id, {"version": self._coerce_version(store.get("version", 1)), "jobs": jobs})
        return {"status": "stopped", "job_id": job_id}

    def stop_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like :meth:`stop_cron_job`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.stop_cron_job(bot_id, job_id)

    def delete_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Remove every job whose id equals *job_id*.

        Raises:
            HTTPException: 404 if no job with that id exists.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        kept = [row for row in jobs if not self._job_matches(row, job_id)]
        if len(kept) == len(jobs):
            raise HTTPException(status_code=404, detail="Cron job not found")
        self.write_cron_store(bot_id, {"version": self._coerce_version(store.get("version", 1)), "jobs": kept})
        return {"status": "deleted", "job_id": job_id}

    def delete_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like :meth:`delete_cron_job`, but raise 404 if the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.delete_cron_job(bot_id, job_id)