137 lines
6.3 KiB
Python
137 lines
6.3 KiB
Python
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from fastapi import HTTPException
|
|
from sqlmodel import Session
|
|
|
|
from clients.edge.base import EdgeClient
|
|
from models.bot import BotInstance
|
|
from providers.runtime.base import RuntimeProvider
|
|
from providers.target import ProviderTarget, provider_target_to_dict
|
|
|
|
|
|
class EdgeRuntimeProvider(RuntimeProvider):
    """Runtime provider that drives bots through edge-node clients.

    Each operation resolves the bot's ``ProviderTarget`` (via the injected
    ``read_provider_target``), turns it into an ``EdgeClient`` (via
    ``resolve_edge_client``) and forwards the call to that client. Targets
    whose ``transport_kind`` is not ``"edge"`` are rejected with HTTP 400.
    All data access is supplied through constructor callables.
    """

    # Accepted spellings of a node's ``native_sandbox_mode`` metadata value
    # (after strip + lower-casing), mapped to the canonical mode. Anything not
    # listed here, including a missing/empty value, normalizes to "inherit".
    _NATIVE_SANDBOX_ALIASES: Dict[str, str] = {
        "workspace": "workspace",
        "sandbox": "workspace",
        "strict": "workspace",
        "full_access": "full_access",
        "full-access": "full_access",
        "danger-full-access": "full_access",
        "escape": "full_access",
    }

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        resolve_env_params: Callable[[str], Dict[str, str]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected collaborators.

        Args:
            read_provider_target: Returns the ``ProviderTarget`` for a bot id.
            resolve_edge_client: Returns the ``EdgeClient`` serving a target.
            read_runtime_snapshot: Returns a bot's runtime settings; this class
                reads ``send_progress``, ``send_tool_hints``, ``cpu_cores``,
                ``memory_mb`` and ``storage_gb`` from it.
            resolve_env_params: Returns the environment variables sent as
                ``env_vars`` when a bot is started.
            read_bot_channels: Returns the channel list forwarded as
                ``channels_override`` during workspace sync.
            read_node_metadata: Returns metadata for a stripped, lower-cased
                node id; ``_node_runtime_overrides`` documents the keys used.
        """
        # Lookup hooks.
        self._read_provider_target = read_provider_target
        self._read_runtime_snapshot = read_runtime_snapshot
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata
        # Resolution hooks.
        self._resolve_edge_client = resolve_edge_client
        self._resolve_env_params = resolve_env_params

    async def start_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Sync a bot's workspace to its edge node, start it, and mark it RUNNING.

        The workspace sync receives the runtime snapshot merged with the
        target's fields and the node overrides (later sources win). Only after
        the client's ``start_bot`` returns is ``bot.docker_status`` set to
        ``"RUNNING"`` and the session committed; if an earlier step raises, the
        status update is skipped.

        Returns:
            Whatever the edge client's ``start_bot`` returned.

        Raises:
            HTTPException: 400 for an empty bot id or a non-edge transport,
                403 if the bot is disabled.
        """
        bot_id = self._checked_bot_id(bot)
        snapshot = self._read_runtime_snapshot(bot)
        target = self._read_provider_target(bot_id)
        client = self._client_for_target(target)
        node_overrides = self._node_runtime_overrides(target.node_id, target.runtime_kind)

        # Precedence: snapshot < target fields < node overrides.
        merged_runtime = dict(snapshot)
        merged_runtime.update(provider_target_to_dict(target))
        merged_runtime.update(node_overrides)

        channels = self._read_bot_channels(bot)
        delivery = {
            wire_key: bool(snapshot.get(snapshot_key))
            for wire_key, snapshot_key in (
                ("sendProgress", "send_progress"),
                ("sendToolHints", "send_tool_hints"),
            )
        }
        # NOTE(review): not awaited, unlike client.start_bot below. Confirm it is
        # synchronous (an un-awaited coroutine would silently skip the sync) and
        # that any blocking I/O here is acceptable inside this async method.
        client.sync_bot_workspace(
            bot_id=bot_id,
            channels_override=channels,
            global_delivery_override=delivery,
            runtime_overrides=merged_runtime,
        )

        payload: Dict[str, Any] = {
            "image_tag": bot.image_tag,
            "runtime_kind": target.runtime_kind,
            "env_vars": self._resolve_env_params(bot_id),
        }
        for resource_key in ("cpu_cores", "memory_mb", "storage_gb"):
            payload[resource_key] = snapshot.get(resource_key)
        payload.update(node_overrides)

        result = await client.start_bot(bot=bot, start_payload=payload)
        self._record_docker_status(session, bot, "RUNNING")
        return result

    def stop_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Stop a bot on its edge node and mark it STOPPED.

        Returns:
            Whatever the edge client's ``stop_bot`` returned.

        Raises:
            HTTPException: 400 for an empty bot id or a non-edge transport,
                403 if the bot is disabled.
        """
        # NOTE(review): this applies the same "disabled" guard as start_bot, so a
        # bot that was disabled while running cannot be stopped here — confirm
        # that is intended.
        bot_id = self._checked_bot_id(bot)
        result = self._client_for_bot(bot_id).stop_bot(bot=bot)
        self._record_docker_status(session, bot, "STOPPED")
        return result

    def deliver_command(
        self,
        *,
        bot_id: str,
        command: str,
        media: Optional[List[str]] = None,
    ) -> Optional[str]:
        """Forward ``command`` (and optional ``media``) to the bot's edge client and return its result."""
        client = self._client_for_bot(bot_id)
        return client.deliver_command(bot_id=bot_id, command=command, media=media)

    def get_recent_logs(self, *, bot_id: str, tail: int = 300) -> List[str]:
        """Return the edge client's last ``tail`` log lines for the bot, as the client returns them."""
        client = self._client_for_bot(bot_id)
        return client.get_recent_logs(bot_id=bot_id, tail=tail)

    def ensure_monitor(self, *, bot_id: str) -> bool:
        """Ask the edge client to ensure a monitor for the bot; coerce its answer to ``bool``."""
        client = self._client_for_bot(bot_id)
        return bool(client.ensure_monitor(bot_id=bot_id))

    def get_monitor_packets(self, *, bot_id: str, after_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """Return monitor packets from the edge client as a new list (empty if the client returns nothing)."""
        client = self._client_for_bot(bot_id)
        packets = client.get_monitor_packets(bot_id=bot_id, after_seq=after_seq, limit=limit)
        return list(packets) if packets else []

    def get_runtime_status(self, *, bot_id: str) -> str:
        """Return the client's runtime status upper-cased, or ``"STOPPED"`` when it reports nothing."""
        raw_status = self._client_for_bot(bot_id).get_runtime_status(bot_id=bot_id)
        return str(raw_status).upper() if raw_status else "STOPPED"

    def get_resource_snapshot(self, *, bot_id: str) -> Dict[str, Any]:
        """Return a copy of the client's resource snapshot (empty dict if it returns nothing)."""
        snapshot = self._client_for_bot(bot_id).get_resource_snapshot(bot_id=bot_id)
        return dict(snapshot) if snapshot else {}

    @staticmethod
    def _checked_bot_id(bot: BotInstance) -> str:
        """Return the stripped bot id, rejecting empty ids and disabled bots.

        A bot without an ``enabled`` attribute is treated as enabled.

        Raises:
            HTTPException: 400 if the id is empty, 403 if ``enabled`` is falsy.
        """
        bot_id = str(bot.id or "").strip()
        if not bot_id:
            raise HTTPException(status_code=400, detail="Bot id is required")
        if not getattr(bot, "enabled", True):
            raise HTTPException(status_code=403, detail="Bot is disabled. Enable it first.")
        return bot_id

    @staticmethod
    def _record_docker_status(session: Session, bot: BotInstance, status: str) -> None:
        """Set ``bot.docker_status`` to ``status``, add the bot to ``session`` and commit."""
        bot.docker_status = status
        session.add(bot)
        session.commit()

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Resolve the bot's provider target and return its edge client."""
        return self._client_for_target(self._read_provider_target(bot_id))

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``; raise HTTP 400 unless its transport is ``"edge"``."""
        if target.transport_kind == "edge":
            return self._resolve_edge_client(target)
        raise HTTPException(
            status_code=400,
            detail=f"edge runtime provider requires edge transport, got {target.transport_kind}",
        )

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime overrides from the node's metadata.

        ``workspace_root`` is included for any runtime kind when non-blank.
        For the ``"native"`` runtime kind (case/whitespace-insensitive) this
        also adds ``native_sandbox_mode`` (unless it normalizes to
        ``"inherit"``) and non-blank ``native_command`` / ``native_workdir``.
        All included values are stripped strings.
        """
        node_key = str(node_id or "").strip().lower()
        metadata = dict(self._read_node_metadata(node_key) or {})

        def clean(field: str) -> str:
            # Missing/None/empty values collapse to "" so they can be skipped.
            return str(metadata.get(field) or "").strip()

        overrides: Dict[str, str] = {}
        root = clean("workspace_root")
        if root:
            overrides["workspace_root"] = root

        is_native = str(runtime_kind or "").strip().lower() == "native"
        if not is_native:
            return overrides

        sandbox = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        if sandbox != "inherit":
            overrides["native_sandbox_mode"] = sandbox

        for field in ("native_command", "native_workdir"):
            value = clean(field)
            if value:
                overrides[field] = value
        return overrides

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a raw sandbox-mode value to ``"workspace"``, ``"full_access"`` or ``"inherit"``."""
        key = str(raw_value or "").strip().lower()
        return EdgeRuntimeProvider._NATIVE_SANDBOX_ALIASES.get(key, "inherit")
|