# 321 lines, 13 KiB, Python
import json
import os
import re
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List

from fastapi import HTTPException
from sqlmodel import Session

from models.bot import BotInstance


# Signatures of the storage / cache callables injected into BotConfigStateService.

# Reads one piece of edge state. The service calls it with the keyword
# arguments ``bot_id``, ``state_key`` and ``default_payload`` and expects a dict.
ReadEdgeStateData = Callable[..., Dict[str, Any]]
# Writes one piece of edge state, called with the keyword arguments ``bot_id``,
# ``state_key`` and ``data``. A falsy return makes the service fall back to
# writing its local JSON file instead.
WriteEdgeStateData = Callable[..., bool]
# bot_id -> the bot's config dict.
ReadBotConfig = Callable[[str], Dict[str, Any]]
# (bot_id, config dict) -> None; writes the bot's config.
WriteBotConfig = Callable[[str, Dict[str, Any]], None]
# bot_id -> None; invalidates cached bot details (called after env / MCP updates).
InvalidateBotCache = Callable[[str], None]
# bot_id -> filesystem path of a per-bot JSON fallback file.
PathResolver = Callable[[str], str]
# Raw env payload of any shape -> normalized ``{name: value}`` string mapping.
NormalizeEnvParams = Callable[[Any], Dict[str, str]]


class BotConfigStateService:
    """Read and update per-bot env params, MCP server settings and cron jobs.

    Env params and the cron store are first read from / written to the edge
    state store through the injected ``read_edge_state_data`` and
    ``write_edge_state_data`` callables. The service falls back to a JSON file
    when the edge store returns nothing, or when the write callable returns a
    falsy value. The file path comes from ``env_store_path`` /
    ``cron_store_path``. MCP servers live under ``config["tools"]["mcpServers"]``
    of the bot config handled by ``read_bot_config`` / ``write_bot_config``.

    Every ``*_for_bot`` method first checks that the bot exists in the
    database and raises ``HTTPException(404)`` if it does not.
    """

    # MCP server names: 1-64 characters from ASCII letters, digits, '.', '_', '-'.
    _MCP_SERVER_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
    # Accepted MCP transport types; any other value is replaced by the default.
    _MCP_TRANSPORT_TYPES = frozenset({"streamableHttp", "sse"})
    _MCP_DEFAULT_TRANSPORT = "streamableHttp"
    # toolTimeout is clamped to [_MCP_TIMEOUT_MIN, _MCP_TIMEOUT_MAX]; unparsable values use the default.
    _MCP_TIMEOUT_DEFAULT = 60
    _MCP_TIMEOUT_MIN = 1
    _MCP_TIMEOUT_MAX = 600
    # Sort key for cron jobs without a usable next run time, so they are listed last.
    _CRON_NO_NEXT_RUN_MS = 2**62

    def __init__(
        self,
        *,
        read_edge_state_data: ReadEdgeStateData,
        write_edge_state_data: WriteEdgeStateData,
        read_bot_config: ReadBotConfig,
        write_bot_config: WriteBotConfig,
        invalidate_bot_detail_cache: InvalidateBotCache,
        env_store_path: PathResolver,
        cron_store_path: PathResolver,
        normalize_env_params: NormalizeEnvParams,
    ) -> None:
        """Store the injected storage, cache and normalization callables.

        Args:
            read_edge_state_data: Reads a state entry (``state_key`` "env" or "cron").
            write_edge_state_data: Writes a state entry; a falsy return triggers the file fallback.
            read_bot_config: Returns the bot's config dict.
            write_bot_config: Writes the bot's config dict.
            invalidate_bot_detail_cache: Called after env / MCP updates.
            env_store_path: Resolves the fallback JSON file for env params.
            cron_store_path: Resolves the fallback JSON file for the cron store.
            normalize_env_params: Normalizes a raw env payload to ``Dict[str, str]``.
        """
        self._read_edge_state_data = read_edge_state_data
        self._write_edge_state_data = write_edge_state_data
        self._read_bot_config = read_bot_config
        self._write_bot_config = write_bot_config
        self._invalidate_bot_detail_cache = invalidate_bot_detail_cache
        self._env_store_path = env_store_path
        self._cron_store_path = cron_store_path
        self._normalize_env_params = normalize_env_params

    # ------------------------------------------------------------------ helpers

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the ``BotInstance`` for *bot_id*.

        Raises:
            HTTPException: 404 when no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @staticmethod
    def _write_json_atomic(path: str, payload: Any) -> None:
        """Write *payload* as indented UTF-8 JSON to *path* via a temp file and ``os.replace``.

        If serialization or the rename fails, the temp file is removed and the
        error is re-raised. Any previous file at *path* is left untouched.
        """
        directory = os.path.dirname(path)
        if directory:
            # ``os.makedirs("")`` raises, so only create a directory when the path names one.
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(payload, file, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @staticmethod
    def _now_ms() -> int:
        """Return the current Unix time in milliseconds."""
        # Use an aware UTC datetime: ``datetime.utcnow().timestamp()`` interprets the
        # naive UTC value as local time and is off by the host's UTC offset.
        return int(datetime.now(timezone.utc).timestamp() * 1000)

    # --------------------------------------------------------------- env params

    def read_env_store(self, bot_id: str) -> Dict[str, str]:
        """Return the bot's normalized env params.

        Prefers the edge state entry ``"env"``. When that is empty, reads the
        JSON file at ``env_store_path(bot_id)``. A missing, unreadable or
        malformed file yields ``{}``.
        """
        data = self._read_edge_state_data(bot_id=bot_id, state_key="env", default_payload={})
        if data:
            return self._normalize_env_params(data)

        path = self._env_store_path(bot_id)
        if not os.path.isfile(path):
            return {}
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
            return self._normalize_env_params(payload)
        except Exception:
            # Best effort: a broken fallback file must not break env reads.
            return {}

    def write_env_store(self, bot_id: str, env_params: Dict[str, str]) -> None:
        """Normalize *env_params* and store them in the edge store, or in the fallback file."""
        normalized_env = self._normalize_env_params(env_params)
        if self._write_edge_state_data(bot_id=bot_id, state_key="env", data=normalized_env):
            return
        self._write_json_atomic(self._env_store_path(bot_id), normalized_env)

    def get_env_params(self, bot_id: str) -> Dict[str, Any]:
        """Return ``{"bot_id": ..., "env_params": {...}}`` for *bot_id*."""
        return {
            "bot_id": bot_id,
            "env_params": self.read_env_store(bot_id),
        }

    def get_env_params_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like :meth:`get_env_params`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_env_params(bot_id)

    def update_env_params(self, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Replace the bot's env params and invalidate its cached details.

        Returns:
            A status payload with the normalized params and ``restart_required=True``.
        """
        normalized = self._normalize_env_params(env_params)
        self.write_env_store(bot_id, normalized)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "env_params": normalized,
            "restart_required": True,
        }

    def update_env_params_for_bot(self, *, session: Session, bot_id: str, env_params: Any) -> Dict[str, Any]:
        """Like :meth:`update_env_params`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_env_params(bot_id, env_params)

    # -------------------------------------------------------------- MCP servers

    @staticmethod
    def _normalize_mcp_headers(raw: Any) -> Dict[str, str]:
        """Return *raw* as a ``{header: value}`` dict of stripped strings, dropping empty header names."""
        if not isinstance(raw, dict):
            return {}
        headers: Dict[str, str] = {}
        for key, value in raw.items():
            header_key = str(key or "").strip()
            if header_key:
                headers[header_key] = str(value or "").strip()
        return headers

    @classmethod
    def _normalize_mcp_timeout(cls, raw: Any) -> int:
        """Parse *raw* as an int toolTimeout, defaulting when unparsable, clamped to the allowed range."""
        try:
            timeout = int(raw)
        except (TypeError, ValueError, OverflowError):
            timeout = cls._MCP_TIMEOUT_DEFAULT
        return max(cls._MCP_TIMEOUT_MIN, min(timeout, cls._MCP_TIMEOUT_MAX))

    def normalize_mcp_servers(self, raw: Any) -> Dict[str, Dict[str, Any]]:
        """Validate and normalize a ``{name: server_cfg}`` mapping of MCP servers.

        Entries are dropped when the name does not match ``_MCP_SERVER_NAME_RE``,
        when the config is not a dict, or when it has no ``url``. Kept entries
        contain exactly ``type``, ``url``, ``headers`` and ``toolTimeout``.
        A non-dict *raw* yields ``{}``.
        """
        if not isinstance(raw, dict):
            return {}
        rows: Dict[str, Dict[str, Any]] = {}
        for server_name, server_cfg in raw.items():
            name = str(server_name or "").strip()
            # The regex requires 1-64 characters, so an empty name is rejected here too.
            if not self._MCP_SERVER_NAME_RE.fullmatch(name) or not isinstance(server_cfg, dict):
                continue
            url = str(server_cfg.get("url") or "").strip()
            if not url:
                continue
            transport_type = str(server_cfg.get("type") or self._MCP_DEFAULT_TRANSPORT).strip()
            if transport_type not in self._MCP_TRANSPORT_TYPES:
                transport_type = self._MCP_DEFAULT_TRANSPORT
            rows[name] = {
                "type": transport_type,
                "url": url,
                "headers": self._normalize_mcp_headers(server_cfg.get("headers")),
                "toolTimeout": self._normalize_mcp_timeout(
                    server_cfg.get("toolTimeout", self._MCP_TIMEOUT_DEFAULT)
                ),
            }
        return rows

    def _merge_mcp_servers_preserving_extras(
        self,
        current_raw: Any,
        normalized: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """Overlay *normalized* configs onto the stored ones, keeping extra keys of kept servers.

        Only servers present in *normalized* survive. For each of them, keys
        already stored in *current_raw* but not produced by normalization are
        preserved, and normalized keys win.
        """
        current_map = current_raw if isinstance(current_raw, dict) else {}
        merged: Dict[str, Dict[str, Any]] = {}
        for name, normalized_cfg in normalized.items():
            base = current_map.get(name)
            next_cfg = dict(base) if isinstance(base, dict) else {}
            next_cfg.update(normalized_cfg)
            merged[name] = next_cfg
        return merged

    def _sanitize_mcp_servers_in_config_data(self, config_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """Re-normalize ``config_data["tools"]["mcpServers"]`` in place and return the result.

        A non-dict *config_data* is left alone and yields ``{}``.
        """
        if not isinstance(config_data, dict):
            return {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        current_raw = tools_cfg.get("mcpServers")
        normalized = self.normalize_mcp_servers(current_raw)
        merged = self._merge_mcp_servers_preserving_extras(current_raw, normalized)
        tools_cfg["mcpServers"] = merged
        config_data["tools"] = tools_cfg
        return merged

    def get_mcp_config(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's normalized MCP servers read from its config."""
        config_data = self._read_bot_config(bot_id)
        tools_cfg = config_data.get("tools") if isinstance(config_data, dict) else {}
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        return {
            "bot_id": bot_id,
            "mcp_servers": self.normalize_mcp_servers(tools_cfg.get("mcpServers")),
            "locked_servers": [],
            "restart_required": True,
        }

    def get_mcp_config_for_bot(self, *, session: Session, bot_id: str) -> Dict[str, Any]:
        """Like :meth:`get_mcp_config`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.get_mcp_config(bot_id)

    def update_mcp_config(self, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Replace the bot's MCP servers with *mcp_servers* and save the config.

        Invalid entries are dropped, and servers absent from *mcp_servers* are
        removed. Extra keys already stored for a kept server are preserved.
        The bot's cached details are invalidated afterwards.
        """
        config_data = self._read_bot_config(bot_id)
        if not isinstance(config_data, dict):
            config_data = {}
        tools_cfg = config_data.get("tools")
        if not isinstance(tools_cfg, dict):
            tools_cfg = {}
        normalized_mcp_servers = self.normalize_mcp_servers(mcp_servers or {})
        tools_cfg["mcpServers"] = self._merge_mcp_servers_preserving_extras(
            tools_cfg.get("mcpServers"), normalized_mcp_servers
        )
        config_data["tools"] = tools_cfg
        sanitized_after_save = self._sanitize_mcp_servers_in_config_data(config_data)
        self._write_bot_config(bot_id, config_data)
        self._invalidate_bot_detail_cache(bot_id)
        return {
            "status": "updated",
            "bot_id": bot_id,
            "mcp_servers": self.normalize_mcp_servers(sanitized_after_save),
            "locked_servers": [],
            "restart_required": True,
        }

    def update_mcp_config_for_bot(self, *, session: Session, bot_id: str, mcp_servers: Any) -> Dict[str, Any]:
        """Like :meth:`update_mcp_config`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.update_mcp_config(bot_id, mcp_servers)

    # ---------------------------------------------------------------- cron jobs

    @staticmethod
    def _empty_cron_store() -> Dict[str, Any]:
        """Return a fresh empty cron store."""
        return {"version": 1, "jobs": []}

    @staticmethod
    def _coerce_cron_version(raw: Any) -> int:
        """Return *raw* as a cron-store version >= 1; unusable or falsy values become 1."""
        try:
            return max(1, int(raw or 1))
        except (TypeError, ValueError, OverflowError):
            return 1

    @classmethod
    def _cron_next_run_key(cls, row: Dict[str, Any]) -> int:
        """Sort key for a job: ``state.nextRunAtMs``, or a far-future value when missing or unusable."""
        state = row.get("state")
        if not isinstance(state, dict):
            return cls._CRON_NO_NEXT_RUN_MS
        raw = state.get("nextRunAtMs")
        if not raw:
            # 0 / None / "" all mean "no next run".
            return cls._CRON_NO_NEXT_RUN_MS
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            return cls._CRON_NO_NEXT_RUN_MS

    @staticmethod
    def _is_cron_job(row: Any, job_id: str) -> bool:
        """Return True when *row* is a job dict whose ``id`` (as a string) equals *job_id*."""
        if not isinstance(row, dict):
            return False
        row_id = row.get("id")
        # A job without an id must not match the literal job_id "None".
        return row_id is not None and str(row_id) == job_id

    def read_cron_store(self, bot_id: str) -> Dict[str, Any]:
        """Return the bot's cron store with an int ``version`` >= 1 and a list ``jobs``.

        A non-empty dict from the edge state entry ``"cron"`` wins, and only
        ``version`` and ``jobs`` are returned from it. Otherwise the JSON file at
        ``cron_store_path(bot_id)`` is used, keeping any other top-level keys.
        A missing, unreadable or non-dict file yields an empty store.
        """
        data = self._read_edge_state_data(
            bot_id=bot_id,
            state_key="cron",
            default_payload={"version": 1, "jobs": []},
        )
        if isinstance(data, dict) and data:
            jobs = data.get("jobs")
            return {
                "version": self._coerce_cron_version(data.get("version", 1)),
                "jobs": jobs if isinstance(jobs, list) else [],
            }

        path = self._cron_store_path(bot_id)
        if not os.path.isfile(path):
            return self._empty_cron_store()
        try:
            with open(path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except Exception:
            # Best effort: a broken fallback file reads as an empty store.
            return self._empty_cron_store()
        if not isinstance(payload, dict):
            return self._empty_cron_store()
        if not isinstance(payload.get("jobs"), list):
            payload["jobs"] = []
        # Coerce like the edge branch so callers can always rely on an int version.
        payload["version"] = self._coerce_cron_version(payload.get("version", 1))
        return payload

    def write_cron_store(self, bot_id: str, store: Dict[str, Any]) -> None:
        """Normalize *store* and save it to the edge store, or to the fallback file."""
        normalized_store = dict(store) if isinstance(store, dict) else {}
        if not isinstance(normalized_store.get("jobs"), list):
            normalized_store["jobs"] = []
        normalized_store["version"] = self._coerce_cron_version(normalized_store.get("version", 1))
        if self._write_edge_state_data(bot_id=bot_id, state_key="cron", data=normalized_store):
            return
        self._write_json_atomic(self._cron_store_path(bot_id), normalized_store)

    def list_cron_jobs(self, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Return the bot's cron jobs sorted by ``state.nextRunAtMs`` (jobs without one last).

        Non-dict entries are skipped. Disabled jobs (``enabled`` falsy) are
        skipped unless *include_disabled* is true.
        """
        store = self.read_cron_store(bot_id)
        rows = [
            row
            for row in store.get("jobs", [])
            if isinstance(row, dict) and (include_disabled or bool(row.get("enabled", True)))
        ]
        rows.sort(key=self._cron_next_run_key)
        return {
            "bot_id": bot_id,
            "version": self._coerce_cron_version(store.get("version", 1)),
            "jobs": rows,
        }

    def list_cron_jobs_for_bot(self, *, session: Session, bot_id: str, include_disabled: bool = True) -> Dict[str, Any]:
        """Like :meth:`list_cron_jobs`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.list_cron_jobs(bot_id, include_disabled=include_disabled)

    def stop_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Disable the job *job_id* and stamp ``updatedAtMs`` with the current Unix time in ms.

        Raises:
            HTTPException: 404 when no job has that id.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        target = next((row for row in jobs if self._is_cron_job(row, job_id)), None)
        if target is None:
            raise HTTPException(status_code=404, detail="Cron job not found")
        target["enabled"] = False
        target["updatedAtMs"] = self._now_ms()
        self.write_cron_store(
            bot_id,
            {"version": self._coerce_cron_version(store.get("version", 1)), "jobs": jobs},
        )
        return {"status": "stopped", "job_id": job_id}

    def stop_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like :meth:`stop_cron_job`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.stop_cron_job(bot_id, job_id)

    def delete_cron_job(self, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Remove every job whose id equals *job_id*.

        Raises:
            HTTPException: 404 when no job has that id.
        """
        store = self.read_cron_store(bot_id)
        jobs = store.get("jobs", [])
        if not isinstance(jobs, list):
            jobs = []
        kept = [row for row in jobs if not self._is_cron_job(row, job_id)]
        if len(kept) == len(jobs):
            raise HTTPException(status_code=404, detail="Cron job not found")
        self.write_cron_store(
            bot_id,
            {"version": self._coerce_cron_version(store.get("version", 1)), "jobs": kept},
        )
        return {"status": "deleted", "job_id": job_id}

    def delete_cron_job_for_bot(self, *, session: Session, bot_id: str, job_id: str) -> Dict[str, Any]:
        """Like :meth:`delete_cron_job`, but raise ``HTTPException(404)`` for an unknown bot."""
        self._require_bot(session=session, bot_id=bot_id)
        return self.delete_cron_job(bot_id, job_id)