import json
import os
import re
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List

from fastapi import HTTPException
from sqlmodel import Session

from models.bot import BotInstance

# Signatures of the collaborators injected into BotConfigStateService.
ReadEdgeStateData = Callable[..., Dict[str, Any]]
WriteEdgeStateData = Callable[..., bool]
ReadBotConfig = Callable[[str], Dict[str, Any]]
WriteBotConfig = Callable[[str, Dict[str, Any]], None]
InvalidateBotCache = Callable[[str], None]
PathResolver = Callable[[str], str]
NormalizeEnvParams = Callable[[Any], Dict[str, str]]


class BotConfigStateService:
    """Read and update per-bot env params, MCP server config and cron jobs.

    All storage access goes through callables supplied to ``__init__``.
    Env params and cron jobs are read from / written to the injected
    "edge state" callables first and fall back to a local JSON file whose
    path is resolved by ``env_store_path`` / ``cron_store_path``.
    """

    # Allowed MCP server names: 1-64 chars of letters, digits, ".", "_", "-".
    _MCP_SERVER_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")

    # Sort key used for cron jobs that have no usable ``nextRunAtMs``.
    _CRON_SORT_LAST = 2**62

    def __init__(
        self,
        *,
        read_edge_state_data: ReadEdgeStateData,
        write_edge_state_data: WriteEdgeStateData,
        read_bot_config: ReadBotConfig,
        write_bot_config: WriteBotConfig,
        invalidate_bot_detail_cache: InvalidateBotCache,
        env_store_path: PathResolver,
        cron_store_path: PathResolver,
        normalize_env_params: NormalizeEnvParams,
    ) -> None:
        """Store the injected storage, cache and normalization callables."""
        self._read_edge_state_data = read_edge_state_data
        self._write_edge_state_data = write_edge_state_data
        self._read_bot_config = read_bot_config
        self._write_bot_config = write_bot_config
        self._invalidate_bot_detail_cache = invalidate_bot_detail_cache
        self._env_store_path = env_store_path
        self._cron_store_path = cron_store_path
        self._normalize_env_params = normalize_env_params

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the bot row for ``bot_id`` or raise HTTP 404 if absent."""
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _write_json_atomically(path: str, payload: Any) -> None:
        """Write ``payload`` as JSON to ``path`` via a temp file + replace.

        The parent directory is created when the path has one. If writing
        or replacing fails, the temporary file is removed (best effort)
        and the original exception is re-raised.
        """
        directory = os.path.dirname(path)
        # os.makedirs("") raises, so only create a directory when one is named.
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(payload, file, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @staticmethod
    def _coerce_cron_version(raw: Any) -> int:
        """Return ``raw`` as an int >= 1, or 1 when it is missing/invalid."""
        try:
            return max(1, int(raw or 1))
        except (TypeError, ValueError, OverflowError):
            return 1

    @classmethod
    def _cron_sort_key(cls, job: Dict[str, Any]) -> int:
        """Sort key for a cron job: ``state.nextRunAtMs``, else sort last."""
        state = job.get("state")
        next_run = state.get("nextRunAtMs") if isinstance(state, dict) else None
        try:
            return int(next_run or cls._CRON_SORT_LAST)
        except (TypeError, ValueError, OverflowError):
            # Malformed timestamps must not break the listing.
            return cls._CRON_SORT_LAST

    # ------------------------------------------------------------------
    # Env params
    # ------------------------------------------------------------------

    def read_env_store(self, bot_id: str) -> Dict[str, str]:
        """Return the normalized env params stored for ``bot_id``.

        Uses the edge state when it yields a truthy payload; otherwise
        reads the JSON file at ``env_store_path(bot_id)``. A missing or
        unreadable file yields ``{}``.
        """
        data = self._read_edge_state_data(bot_id=bot_id, state_key="env", default_payload={})
        if data:
            return self._normalize_env_params(data)
        path = self._env_store_path(bot_id)
        if not os.path.isfile(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
            return self._normalize_env_params(payload)
        except Exception:
            # Deliberate best effort: a corrupt file reads as "no env params".
            return {}

    def write_env_store(self, bot_id: str, env_params: Dict[str, str]) -> None:
        """Persist normalized ``env_params`` for ``bot_id``.

        The edge state writer is tried first; only when it returns a falsy
        value is the local JSON file written.
        """
        normalized_env = self._normalize_env_params(env_params)
        if self._write_edge_state_data(bot_id=bot_id, state_key="env", data=normalized_env):
            return
        self._write_json_atomically(self._env_store_path(bot_id), normalized_env)

    def get_env_params(self, bot_id: str) -> Dict[str, Any]:
        """Return ``{"bot_id", "env_params"}`` for ``bot_id``."""
        return {
            "bot_id": bot_id,
            "env_params": self.read_env_store(bot_id),
        }

    def get_env_params_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like ``get_env_params`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_env_params(bot_id)

    def update_env_params(self, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Normalize and store ``env_params``, then invalidate the bot cache."""
        normalized = self._normalize_env_params(env_params)
        self.write_env_store(bot_id, normalized)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "env_params": normalized,
            "restart_required": True,
        }

    def update_env_params_for_bot(self, *, session: Session, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Like ``update_env_params`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_env_params(bot_id, env_params)

    # ------------------------------------------------------------------
    # MCP servers
    # ------------------------------------------------------------------

    def normalize_mcp_servers(self, raw: Any) -> Dict[str, Dict[str, Any]]:
        """Return a cleaned ``name -> config`` map of MCP servers.

        Entries are dropped when the name does not match
        ``_MCP_SERVER_NAME_RE``, the config is not a dict, or ``url`` is
        empty. ``type`` falls back to ``"streamableHttp"`` unless it is
        ``"streamableHttp"`` or ``"sse"``; ``toolTimeout`` defaults to 60
        and is clamped to 1..600. Non-dict input yields ``{}``.
        """
        if not isinstance(raw, dict):
            return {}
        rows: Dict[str, Dict[str, Any]] = {}
        for server_name, server_cfg in raw.items():
            name = str(server_name or "").strip()
            if not name or not self._MCP_SERVER_NAME_RE.fullmatch(name):
                continue
            if not isinstance(server_cfg, dict):
                continue
            url = str(server_cfg.get("url") or "").strip()
            if not url:
                continue

            transport_type = str(server_cfg.get("type") or "streamableHttp").strip()
            if transport_type not in {"streamableHttp", "sse"}:
                transport_type = "streamableHttp"

            headers: Dict[str, str] = {}
            headers_raw = server_cfg.get("headers")
            if isinstance(headers_raw, dict):
                for key, value in headers_raw.items():
                    header_key = str(key or "").strip()
                    if not header_key:
                        continue
                    headers[header_key] = str(value or "").strip()

            try:
                timeout = int(server_cfg.get("toolTimeout", 60))
            except (TypeError, ValueError, OverflowError):
                timeout = 60
            timeout = max(1, min(timeout, 600))

            rows[name] = {
                "type": transport_type,
                "url": url,
                "headers": headers,
                "toolTimeout": timeout,
            }
        return rows

    def _merge_mcp_servers_preserving_extras(
        self,
        current_raw: Any,
        normalized: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Overlay ``normalized`` configs on the currently stored ones.

        Only servers present in ``normalized`` are kept. For each, keys of
        the existing config that normalization does not produce are
        preserved; normalized keys win on conflict.
        """
        current_map = current_raw if isinstance(current_raw, dict) else {}
        merged: Dict[str, Dict[str, Any]] = {}
        for name, normalized_cfg in normalized.items():
            base = current_map.get(name)
            next_cfg = dict(base) if isinstance(base, dict) else {}
            next_cfg.update(normalized_cfg)
            merged[name] = next_cfg
        return merged

    def _sanitize_mcp_servers_in_config_data(self, config_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Normalize ``config_data["tools"]["mcpServers"]`` in place.

        Returns the sanitized server map that was written back, or ``{}``
        (without touching anything) when ``config_data`` is not a dict.
        """
        if not isinstance(config_data, dict):
            return {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        current_raw = tools_cfg.get("mcpServers")
        normalized = self.normalize_mcp_servers(current_raw)
        merged = self._merge_mcp_servers_preserving_extras(current_raw, normalized)
        tools_cfg["mcpServers"] = merged
        config_data["tools"] = tools_cfg
        return merged

    def get_mcp_config(self, bot_id: str) -> Dict[str, Any]:
        """Return the normalized MCP server config read from the bot config."""
        config_data = self._read_bot_config(bot_id)
        tools_cfg = config_data.get("tools") if isinstance(config_data, dict) else {}
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        mcp_servers = self.normalize_mcp_servers(tools_cfg.get("mcpServers"))
        return {
            "bot_id": bot_id,
            "mcp_servers": mcp_servers,
            "locked_servers": [],
            "restart_required": True,
        }

    def get_mcp_config_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like ``get_mcp_config`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_mcp_config(bot_id)

    def update_mcp_config(self, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Replace the bot's MCP servers with the normalized ``mcp_servers``.

        Extra keys on servers that already exist in the stored config are
        preserved; servers absent from ``mcp_servers`` are removed. The bot
        config is written back and the bot detail cache invalidated.
        """
        config_data = self._read_bot_config(bot_id)
        if not isinstance(config_data, dict):
            config_data = {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}

        normalized_mcp_servers = self.normalize_mcp_servers(mcp_servers or {})
        current_mcp_servers = tools_cfg.get("mcpServers")
        tools_cfg["mcpServers"] = self._merge_mcp_servers_preserving_extras(
            current_mcp_servers,
            normalized_mcp_servers,
        )
        config_data["tools"] = tools_cfg

        # Re-run sanitization on the merged result before persisting it.
        sanitized_after_save = self._sanitize_mcp_servers_in_config_data(config_data)
        self._write_bot_config(bot_id, config_data)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "mcp_servers": self.normalize_mcp_servers(sanitized_after_save),
            "locked_servers": [],
            "restart_required": True,
        }

    def update_mcp_config_for_bot(self, *, session: Session, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Like ``update_mcp_config`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_mcp_config(bot_id, mcp_servers)

    # ------------------------------------------------------------------
    # Cron jobs
    # ------------------------------------------------------------------

    def read_cron_store(self, bot_id: str) -> Dict[str, Any]:
        """Return the cron store for ``bot_id`` with ``version`` and ``jobs``.

        A non-empty dict from the edge state is reduced to
        ``{"version", "jobs"}``. Otherwise the JSON file at
        ``cron_store_path(bot_id)`` is read; its other keys are kept, and
        ``version`` / ``jobs`` are normalized the same way. Any failure
        yields ``{"version": 1, "jobs": []}``.
        """
        data = self._read_edge_state_data(
            bot_id=bot_id,
            state_key="cron",
            default_payload={"version": 1, "jobs": []},
        )
        if isinstance(data, dict) and data:
            jobs = data.get("jobs")
            if not isinstance(jobs, list):
                jobs = []
            return {"version": self._coerce_cron_version(data.get("version", 1)), "jobs": jobs}

        path = self._cron_store_path(bot_id)
        if not os.path.isfile(path):
            return {"version": 1, "jobs": []}
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
            if not isinstance(payload, dict):
                return {"version": 1, "jobs": []}
            if not isinstance(payload.get("jobs"), list):
                payload["jobs"] = []
            # Coerce here too so callers can rely on an int version
            # regardless of which backend served the store.
            payload["version"] = self._coerce_cron_version(payload.get("version", 1))
            return payload
        except Exception:
            # Deliberate best effort: a corrupt file reads as an empty store.
            return {"version": 1, "jobs": []}

    def write_cron_store(self, bot_id: str, store: Dict[str, Any]) -> None:
        """Persist ``store`` for ``bot_id`` after normalizing jobs/version.

        The edge state writer is tried first; only when it returns a falsy
        value is the local JSON file written. ``store`` itself is not
        mutated (a shallow copy is normalized).
        """
        normalized_store = dict(store if isinstance(store, dict) else {})
        if not isinstance(normalized_store.get("jobs"), list):
            normalized_store["jobs"] = []
        normalized_store["version"] = self._coerce_cron_version(normalized_store.get("version", 1))
        if self._write_edge_state_data(bot_id=bot_id, state_key="cron", data=normalized_store):
            return
        self._write_json_atomically(self._cron_store_path(bot_id), normalized_store)

    def list_cron_jobs(self, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Return the bot's cron jobs sorted by ``state.nextRunAtMs``.

        Non-dict rows are skipped. Jobs whose ``enabled`` is falsy are
        skipped when ``include_disabled`` is false. Jobs without a usable
        next-run timestamp sort last.
        """
        store = self.read_cron_store(bot_id)
        rows = []
        for row in store.get("jobs", []):
            if not isinstance(row, dict):
                continue
            if not include_disabled and not bool(row.get("enabled", True)):
                continue
            rows.append(row)
        rows.sort(key=self._cron_sort_key)
        return {
            "bot_id": bot_id,
            "version": self._coerce_cron_version(store.get("version", 1)),
            "jobs": rows,
        }

    def list_cron_jobs_for_bot(self, *, session: Session, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Like ``list_cron_jobs`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.list_cron_jobs(bot_id, include_disabled=include_disabled)

    def stop_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Disable the cron job ``job_id`` and stamp ``updatedAtMs``.

        Raises:
            HTTPException: 404 when no job with that id exists.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        found = next(
            (row for row in jobs if isinstance(row, dict) and str(row.get("id")) == job_id),
            None,
        )
        if found is None:
            raise HTTPException(status_code=404, detail="Cron job not found")
        found["enabled"] = False
        # Use an aware UTC datetime: calling .timestamp() on a naive
        # utcnow() value interprets it as local time and skews the epoch.
        found["updatedAtMs"] = int(datetime.now(timezone.utc).timestamp() * 1000)
        self.write_cron_store(
            bot_id,
            {"version": self._coerce_cron_version(store.get("version", 1)), "jobs": jobs},
        )
        return {"status": "stopped", "job_id": job_id}

    def stop_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like ``stop_cron_job`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.stop_cron_job(bot_id, job_id)

    def delete_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Remove every cron job whose id equals ``job_id``.

        Raises:
            HTTPException: 404 when no job with that id exists.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        kept = [row for row in jobs if not (isinstance(row, dict) and str(row.get("id")) == job_id)]
        if len(kept) == len(jobs):
            raise HTTPException(status_code=404, detail="Cron job not found")
        self.write_cron_store(
            bot_id,
            {"version": self._coerce_cron_version(store.get("version", 1)), "jobs": kept},
        )
        return {"status": "deleted", "job_id": job_id}

    def delete_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like ``delete_cron_job`` but raises HTTP 404 if the bot is unknown."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.delete_cron_job(bot_id, job_id)