# dashboard-nanobot/backend/services/bot_config_state_service.py
#
# 321 lines
# 13 KiB
# Python

import json
import os
import re
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List

from fastapi import HTTPException
from sqlmodel import Session

from models.bot import BotInstance
# Signatures of the persistence hooks injected into BotConfigStateService.

# Bot config hooks: read(bot_id) -> config dict, write(bot_id, config) -> None.
ReadBotConfig = Callable[[str], Dict[str, Any]]
WriteBotConfig = Callable[[str, Dict[str, Any]], None]

# Edge-state hooks, called with keyword arguments. The service treats a falsy
# result from the writer as "not stored" and falls back to a local JSON file.
ReadEdgeStateData = Callable[..., Dict[str, Any]]
WriteEdgeStateData = Callable[..., bool]

# Hooks keyed by bot_id: cache invalidation and on-disk store path lookup.
InvalidateBotCache = Callable[[str], None]
PathResolver = Callable[[str], str]

# Turns arbitrary user input into a clean {name: value} mapping of env params.
NormalizeEnvParams = Callable[[Any], Dict[str, str]]
class BotConfigStateService:
    """Per-bot config state: environment params, MCP servers and cron jobs.

    All storage is delegated to the callables passed to ``__init__``:

    * Env params and the cron store are written through ``write_edge_state_data``.
      When that returns a falsy value, the service falls back to a JSON file at
      the path given by ``env_store_path`` / ``cron_store_path``. Reads prefer
      edge state and fall back to the same file.
    * MCP servers live under ``config["tools"]["mcpServers"]`` in the bot config
      accessed through ``read_bot_config`` / ``write_bot_config``.

    The ``*_for_bot`` variants first look the bot up with the given DB session
    and raise ``HTTPException(404)`` when it does not exist.
    """

    _MCP_SERVER_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
    # Transport types accepted for an MCP server; anything else becomes the default.
    _MCP_TRANSPORT_TYPES = frozenset({"streamableHttp", "sse"})
    _MCP_DEFAULT_TRANSPORT = "streamableHttp"
    # toolTimeout default and inclusive clamp range (unit not visible here; presumably seconds).
    _MCP_DEFAULT_TOOL_TIMEOUT = 60
    _MCP_TOOL_TIMEOUT_RANGE = (1, 600)
    # Sort key for cron jobs without a usable nextRunAtMs, so they are listed last.
    _CRON_NO_NEXT_RUN = 2**62

    def __init__(
        self,
        *,
        read_edge_state_data: ReadEdgeStateData,
        write_edge_state_data: WriteEdgeStateData,
        read_bot_config: ReadBotConfig,
        write_bot_config: WriteBotConfig,
        invalidate_bot_detail_cache: InvalidateBotCache,
        env_store_path: PathResolver,
        cron_store_path: PathResolver,
        normalize_env_params: NormalizeEnvParams,
    ) -> None:
        """Store the injected persistence hooks.

        Args:
            read_edge_state_data: Called as ``(bot_id=, state_key=, default_payload=)``.
            write_edge_state_data: Called as ``(bot_id=, state_key=, data=)``; a falsy
                result makes the service write the local JSON file instead.
            read_bot_config: ``bot_id -> config dict``.
            write_bot_config: ``(bot_id, config dict) -> None``.
            invalidate_bot_detail_cache: Called with ``bot_id`` after env/MCP updates.
            env_store_path: ``bot_id -> path`` of the fallback env JSON file.
            cron_store_path: ``bot_id -> path`` of the fallback cron JSON file.
            normalize_env_params: Cleans raw env params into ``{str: str}``.
        """
        self._read_edge_state_data = read_edge_state_data
        self._write_edge_state_data = write_edge_state_data
        self._read_bot_config = read_bot_config
        self._write_bot_config = write_bot_config
        self._invalidate_bot_detail_cache = invalidate_bot_detail_cache
        self._env_store_path = env_store_path
        self._cron_store_path = cron_store_path
        self._normalize_env_params = normalize_env_params

    # ------------------------------------------------------------------ helpers

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the ``BotInstance`` for ``bot_id``.

        Raises:
            HTTPException: 404 when ``session.get`` finds no such bot.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _coerce_version(raw: Any) -> int:
        """Return ``raw`` as a store version >= 1; missing or non-numeric values become 1."""
        try:
            return max(1, int(raw or 1))
        except (TypeError, ValueError, OverflowError):
            return 1

    @staticmethod
    def _read_json_file(path: str) -> Any:
        """Return the parsed JSON content of ``path``.

        Returns ``None`` when the file does not exist, cannot be read, or is not
        valid JSON. A file containing a literal ``null`` is indistinguishable
        from a missing file.
        """
        if not os.path.isfile(path):
            return None
        try:
            with open(path, "r", encoding="utf-8") as file:
                return json.load(file)
        except (OSError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            return None

    @staticmethod
    def _write_json_atomic(path: str, payload: Any) -> None:
        """Write ``payload`` as indented UTF-8 JSON to ``path``.

        The data goes to ``<path>.tmp`` first and is then moved into place with
        ``os.replace``, so readers never see a partially written file. The temp
        file is removed again if serialization or the write fails.
        """
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises for a bare filename
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(payload, file, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @staticmethod
    def _matches_cron_job(row: Any, job_id: str) -> bool:
        """Return True if ``row`` is a job dict whose ``id`` (as a string) equals ``job_id``.

        Rows without an ``id`` never match. Otherwise a lookup for ``"None"``
        would hit them through ``str(None)``.
        """
        if not isinstance(row, dict):
            return False
        row_id = row.get("id")
        return row_id is not None and str(row_id) == job_id

    @classmethod
    def _cron_next_run_sort_key(cls, job: Dict[str, Any]) -> int:
        """Sort key ordering jobs by ``state.nextRunAtMs``.

        Jobs whose value is missing, zero, or not an integer sort last. One
        malformed job therefore cannot break the whole listing.
        """
        state = job.get("state")
        next_run = state.get("nextRunAtMs") if isinstance(state, dict) else None
        try:
            return int(next_run or cls._CRON_NO_NEXT_RUN)
        except (TypeError, ValueError, OverflowError):
            return cls._CRON_NO_NEXT_RUN

    # --------------------------------------------------------------- env params

    def read_env_store(self, bot_id: str) -> Dict[str, str]:
        """Return the bot's normalized env params.

        Edge state is preferred. If it is empty, the JSON file at
        ``env_store_path(bot_id)`` is used. Returns ``{}`` when neither source
        has usable data.
        """
        data = self._read_edge_state_data(bot_id=bot_id, state_key="env", default_payload={})
        if data:
            return self._normalize_env_params(data)
        payload = self._read_json_file(self._env_store_path(bot_id))
        if payload is None:
            return {}
        try:
            return self._normalize_env_params(payload)
        except Exception:
            # Best effort: a malformed fallback file must not break reads.
            return {}

    def write_env_store(self, bot_id: str, env_params: Dict[str, str]) -> None:
        """Normalize ``env_params`` and persist them to edge state, or to the JSON file as a fallback."""
        normalized_env = self._normalize_env_params(env_params)
        if self._write_edge_state_data(bot_id=bot_id, state_key="env", data=normalized_env):
            return
        self._write_json_atomic(self._env_store_path(bot_id), normalized_env)

    def get_env_params(self, bot_id: str) -> Dict[str, Any]:
        """Return ``{"bot_id", "env_params"}`` without checking that the bot exists."""
        return {
            "bot_id": bot_id,
            "env_params": self.read_env_store(bot_id),
        }

    def get_env_params_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like ``get_env_params`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_env_params(bot_id)

    def update_env_params(self, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Normalize and store ``env_params``, then invalidate the bot detail cache.

        The response always reports ``restart_required: True``.
        """
        normalized = self._normalize_env_params(env_params)
        self.write_env_store(bot_id, normalized)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "env_params": normalized,
            "restart_required": True,
        }

    def update_env_params_for_bot(self, *, session: Session, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Like ``update_env_params`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_env_params(bot_id, env_params)

    # -------------------------------------------------------------- MCP servers

    def normalize_mcp_servers(self, raw: Any) -> Dict[str, Dict[str, Any]]:
        """Validate and canonicalize a raw ``{server_name: config}`` mapping.

        An entry is dropped when any of these hold:
        * its stripped name does not match ``_MCP_SERVER_NAME_RE``;
        * its config is not a dict;
        * its ``url`` is empty.

        Each kept entry has exactly ``type``, ``url``, ``headers`` and
        ``toolTimeout``:
        * ``type`` must be ``streamableHttp`` or ``sse``, else the default;
        * ``headers`` drops blank keys and strips values;
        * ``toolTimeout`` is an int clamped to [1, 600], defaulting to 60.

        Returns ``{}`` for non-dict input.
        """
        if not isinstance(raw, dict):
            return {}
        timeout_low, timeout_high = self._MCP_TOOL_TIMEOUT_RANGE
        rows: Dict[str, Dict[str, Any]] = {}
        for server_name, server_cfg in raw.items():
            name = str(server_name or "").strip()
            if not name or not self._MCP_SERVER_NAME_RE.fullmatch(name):
                continue
            if not isinstance(server_cfg, dict):
                continue
            url = str(server_cfg.get("url") or "").strip()
            if not url:
                continue
            transport_type = str(server_cfg.get("type") or self._MCP_DEFAULT_TRANSPORT).strip()
            if transport_type not in self._MCP_TRANSPORT_TYPES:
                transport_type = self._MCP_DEFAULT_TRANSPORT
            headers: Dict[str, str] = {}
            headers_raw = server_cfg.get("headers")
            if isinstance(headers_raw, dict):
                for key, value in headers_raw.items():
                    header_key = str(key or "").strip()
                    if header_key:
                        headers[header_key] = str(value or "").strip()
            try:
                timeout = int(server_cfg.get("toolTimeout", self._MCP_DEFAULT_TOOL_TIMEOUT))
            except (TypeError, ValueError, OverflowError):
                timeout = self._MCP_DEFAULT_TOOL_TIMEOUT
            rows[name] = {
                "type": transport_type,
                "url": url,
                "headers": headers,
                "toolTimeout": max(timeout_low, min(timeout, timeout_high)),
            }
        return rows

    def _merge_mcp_servers_preserving_extras(
        self,
        current_raw: Any,
        normalized: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Overlay normalized server configs onto the stored ones.

        Only servers present in ``normalized`` are returned. For each, keys
        already stored in ``current_raw[name]`` that normalization does not
        produce are kept, and the normalized keys take precedence.
        """
        current_map = current_raw if isinstance(current_raw, dict) else {}
        merged: Dict[str, Dict[str, Any]] = {}
        for name, normalized_cfg in normalized.items():
            base = current_map.get(name)
            merged[name] = {**(base if isinstance(base, dict) else {}), **normalized_cfg}
        return merged

    def _sanitize_mcp_servers_in_config_data(self, config_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Normalize ``config_data["tools"]["mcpServers"]`` in place and return the result.

        Creates ``tools`` when it is missing or not a dict. Returns ``{}`` and
        changes nothing when ``config_data`` is not a dict.
        """
        if not isinstance(config_data, dict):
            return {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        current_raw = tools_cfg.get("mcpServers")
        merged = self._merge_mcp_servers_preserving_extras(current_raw, self.normalize_mcp_servers(current_raw))
        tools_cfg["mcpServers"] = merged
        config_data["tools"] = tools_cfg
        return merged

    def get_mcp_config(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's normalized MCP servers without checking that the bot exists."""
        config_data = self._read_bot_config(bot_id)
        tools_cfg = config_data.get("tools") if isinstance(config_data, dict) else {}
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        return {
            "bot_id": bot_id,
            "mcp_servers": self.normalize_mcp_servers(tools_cfg.get("mcpServers")),
            "locked_servers": [],
            "restart_required": True,
        }

    def get_mcp_config_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like ``get_mcp_config`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_mcp_config(bot_id)

    def update_mcp_config(self, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Replace the bot's MCP servers with the normalized ``mcp_servers``.

        Servers missing from ``mcp_servers`` (or rejected by normalization) are
        removed from the config. For the remaining servers, extra keys already
        stored are kept. The config is written back, the bot detail cache is
        invalidated, and the normalized servers are returned.
        """
        config_data = self._read_bot_config(bot_id)
        if not isinstance(config_data, dict):
            config_data = {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        normalized_mcp_servers = self.normalize_mcp_servers(mcp_servers or {})
        # The merge result is already normalized, so no second sanitize pass is needed.
        merged_mcp_servers = self._merge_mcp_servers_preserving_extras(
            tools_cfg.get("mcpServers"), normalized_mcp_servers
        )
        tools_cfg["mcpServers"] = merged_mcp_servers
        config_data["tools"] = tools_cfg
        self._write_bot_config(bot_id, config_data)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            # Fresh copy so the response does not share dicts with the saved config.
            "mcp_servers": self.normalize_mcp_servers(merged_mcp_servers),
            "locked_servers": [],
            "restart_required": True,
        }

    def update_mcp_config_for_bot(self, *, session: Session, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Like ``update_mcp_config`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_mcp_config(bot_id, mcp_servers)

    # ---------------------------------------------------------------- cron jobs

    def read_cron_store(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's cron store as a dict with a list ``jobs`` and an int ``version`` >= 1.

        Non-empty edge state is preferred and reduced to ``version`` and
        ``jobs``. Otherwise the JSON file at ``cron_store_path(bot_id)`` is used
        with its other keys kept. Returns an empty store when neither source is
        usable.
        """
        data = self._read_edge_state_data(
            bot_id=bot_id,
            state_key="cron",
            default_payload={"version": 1, "jobs": []},
        )
        if isinstance(data, dict) and data:
            jobs = data.get("jobs")
            return {
                "version": self._coerce_version(data.get("version", 1)),
                "jobs": jobs if isinstance(jobs, list) else [],
            }
        payload = self._read_json_file(self._cron_store_path(bot_id))
        if not isinstance(payload, dict):
            return {"version": 1, "jobs": []}
        if not isinstance(payload.get("jobs"), list):
            payload["jobs"] = []
        # Coerce like the edge-state path: a bad version would otherwise crash the int() calls downstream.
        payload["version"] = self._coerce_version(payload.get("version", 1))
        return payload

    def write_cron_store(self, bot_id: str, store: Dict[str, Any]) -> None:
        """Normalize ``store`` and persist it to edge state, or to the JSON file as a fallback.

        ``jobs`` becomes ``[]`` unless it is a list, and ``version`` becomes an
        int >= 1. The caller's dict is shallow-copied, not mutated.
        """
        normalized_store = dict(store if isinstance(store, dict) else {})
        if not isinstance(normalized_store.get("jobs"), list):
            normalized_store["jobs"] = []
        normalized_store["version"] = self._coerce_version(normalized_store.get("version", 1))
        if self._write_edge_state_data(bot_id=bot_id, state_key="cron", data=normalized_store):
            return
        self._write_json_atomic(self._cron_store_path(bot_id), normalized_store)

    def list_cron_jobs(self, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """List the bot's cron jobs ordered by ``state.nextRunAtMs`` (jobs without one last).

        Non-dict rows are skipped. Jobs with a falsy ``enabled`` (default True)
        are skipped unless ``include_disabled`` is set.
        """
        store = self.read_cron_store(bot_id)
        rows = [
            row
            for row in store.get("jobs", [])
            if isinstance(row, dict) and (include_disabled or bool(row.get("enabled", True)))
        ]
        rows.sort(key=self._cron_next_run_sort_key)
        return {"bot_id": bot_id, "version": self._coerce_version(store.get("version", 1)), "jobs": rows}

    def list_cron_jobs_for_bot(self, *, session: Session, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Like ``list_cron_jobs`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.list_cron_jobs(bot_id, include_disabled=include_disabled)

    def stop_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Disable the job ``job_id`` and stamp ``updatedAtMs`` (epoch milliseconds).

        Raises:
            HTTPException: 404 when no job has that id.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        found = next((row for row in jobs if self._matches_cron_job(row, job_id)), None)
        if found is None:
            raise HTTPException(status_code=404, detail="Cron job not found")
        found["enabled"] = False
        # Aware UTC datetime: datetime.utcnow() is naive, and .timestamp() would read
        # it as local time, skewing the epoch value by the host's UTC offset.
        found["updatedAtMs"] = int(datetime.now(timezone.utc).timestamp() * 1000)
        self.write_cron_store(bot_id, {"version": self._coerce_version(store.get("version", 1)), "jobs": jobs})
        return {"status": "stopped", "job_id": job_id}

    def stop_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like ``stop_cron_job`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.stop_cron_job(bot_id, job_id)

    def delete_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Remove every job whose id is ``job_id``.

        Raises:
            HTTPException: 404 when no job has that id.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        kept = [row for row in jobs if not self._matches_cron_job(row, job_id)]
        if len(kept) == len(jobs):
            raise HTTPException(status_code=404, detail="Cron job not found")
        self.write_cron_store(bot_id, {"version": self._coerce_version(store.get("version", 1)), "jobs": kept})
        return {"status": "deleted", "job_id": job_id}

    def delete_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like ``delete_cron_job`` but raises 404 when the bot does not exist."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.delete_cron_job(bot_id, job_id)