dashboard-nanobot/backend/providers/runtime/edge.py

137 lines
6.3 KiB
Python

from typing import Any, Callable, Dict, List, Optional
from fastapi import HTTPException
from sqlmodel import Session
from clients.edge.base import EdgeClient
from models.bot import BotInstance
from providers.runtime.base import RuntimeProvider
from providers.target import ProviderTarget, provider_target_to_dict
class EdgeRuntimeProvider(RuntimeProvider):
    """Runtime provider that runs bots through an ``EdgeClient``.

    Every collaborator is injected as a callable. Each runtime operation
    resolves the bot's ``ProviderTarget`` and delegates to the edge client
    returned for that target. Only targets whose ``transport_kind`` is
    ``"edge"`` are accepted.
    """

    def __init__(
        self,
        *,
        read_provider_target: Callable[[str], ProviderTarget],
        resolve_edge_client: Callable[[ProviderTarget], EdgeClient],
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        resolve_env_params: Callable[[str], Dict[str, str]],
        read_bot_channels: Callable[[BotInstance], List[Dict[str, Any]]],
        read_node_metadata: Callable[[str], Dict[str, Any]],
    ) -> None:
        """Store the injected lookups.

        Args:
            read_provider_target: Maps a bot id to its ``ProviderTarget``.
            resolve_edge_client: Maps a target to the ``EdgeClient`` to use.
            read_runtime_snapshot: Returns runtime settings for a bot. Keys
                read here: ``send_progress``, ``send_tool_hints``,
                ``cpu_cores``, ``memory_mb``, ``storage_gb``.
            resolve_env_params: Returns environment variables for a bot id.
            read_bot_channels: Returns the bot's channel configuration list.
            read_node_metadata: Returns metadata for a (stripped, lower-cased)
                node id.
        """
        self._read_provider_target = read_provider_target
        self._resolve_edge_client = resolve_edge_client
        self._read_runtime_snapshot = read_runtime_snapshot
        self._resolve_env_params = resolve_env_params
        self._read_bot_channels = read_bot_channels
        self._read_node_metadata = read_node_metadata

    async def start_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Sync the bot's workspace to its edge node, start it, and persist RUNNING.

        Returns:
            Whatever the edge client's ``start_bot`` returns.

        Raises:
            HTTPException: 400 if the bot has no id or its target is not an
                edge transport; 403 if the bot is disabled.
        """
        bot_id = self._require_active_bot_id(bot)
        # Tolerate a missing snapshot, the same way node metadata is tolerated.
        runtime_snapshot = dict(self._read_runtime_snapshot(bot) or {})
        target = self._read_provider_target(bot_id)
        client = self._client_for_target(target)
        node_runtime_overrides = self._node_runtime_overrides(target.node_id, target.runtime_kind)
        # Later mappings win. Provider-target fields override snapshot values,
        # and node metadata overrides both.
        workspace_runtime = {
            **runtime_snapshot,
            **provider_target_to_dict(target),
            **node_runtime_overrides,
        }
        # NOTE(review): this call is not awaited. If it performs network I/O,
        # it blocks the event loop — confirm whether it should be async or
        # offloaded (e.g. asyncio.to_thread).
        client.sync_bot_workspace(
            bot_id=bot_id,
            channels_override=self._read_bot_channels(bot),
            global_delivery_override={
                "sendProgress": bool(runtime_snapshot.get("send_progress")),
                "sendToolHints": bool(runtime_snapshot.get("send_tool_hints")),
            },
            runtime_overrides=workspace_runtime,
        )
        result = await client.start_bot(
            bot=bot,
            start_payload={
                "image_tag": bot.image_tag,
                "runtime_kind": target.runtime_kind,
                "env_vars": self._resolve_env_params(bot_id),
                "cpu_cores": runtime_snapshot.get("cpu_cores"),
                "memory_mb": runtime_snapshot.get("memory_mb"),
                "storage_gb": runtime_snapshot.get("storage_gb"),
                **node_runtime_overrides,
            },
        )
        # NOTE(review): status is set without inspecting ``result``. Confirm the
        # client raises on failure rather than returning an error payload.
        bot.docker_status = "RUNNING"
        session.add(bot)
        session.commit()
        return result

    def stop_bot(self, *, session: Session, bot: BotInstance) -> Dict[str, Any]:
        """Stop the bot on its edge node and persist STOPPED.

        Returns:
            Whatever the edge client's ``stop_bot`` returns.

        Raises:
            HTTPException: 400 if the bot has no id or its target is not an
                edge transport; 403 if the bot is disabled.
        """
        # NOTE(review): a disabled bot cannot be stopped here. That mirrors
        # start_bot, but may strand a bot disabled while running — confirm
        # this is intended.
        bot_id = self._require_active_bot_id(bot)
        result = self._client_for_bot(bot_id).stop_bot(bot=bot)
        bot.docker_status = "STOPPED"
        session.add(bot)
        session.commit()
        return result

    def deliver_command(self, *, bot_id: str, command: str, media: Optional[List[str]] = None) -> Optional[str]:
        """Forward a command (with optional media references) to the bot's edge client."""
        return self._client_for_bot(bot_id).deliver_command(bot_id=bot_id, command=command, media=media)

    def get_recent_logs(self, *, bot_id: str, tail: int = 300) -> List[str]:
        """Return up to ``tail`` recent log lines as reported by the edge client."""
        return self._client_for_bot(bot_id).get_recent_logs(bot_id=bot_id, tail=tail)

    def ensure_monitor(self, *, bot_id: str) -> bool:
        """Ask the edge client to ensure a monitor exists; return its result as a bool."""
        return bool(self._client_for_bot(bot_id).ensure_monitor(bot_id=bot_id))

    def get_monitor_packets(self, *, bot_id: str, after_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """Return monitor packets after ``after_seq`` as a list (empty if the client returns nothing)."""
        return list(self._client_for_bot(bot_id).get_monitor_packets(bot_id=bot_id, after_seq=after_seq, limit=limit) or [])

    def get_runtime_status(self, *, bot_id: str) -> str:
        """Return the upper-cased runtime status, defaulting to ``"STOPPED"`` when empty."""
        return str(self._client_for_bot(bot_id).get_runtime_status(bot_id=bot_id) or "STOPPED").upper()

    def get_resource_snapshot(self, *, bot_id: str) -> Dict[str, Any]:
        """Return the edge-reported resource snapshot as a dict (empty if none)."""
        return dict(self._client_for_bot(bot_id).get_resource_snapshot(bot_id=bot_id) or {})

    @staticmethod
    def _require_active_bot_id(bot: BotInstance) -> str:
        """Return the bot's stripped id, rejecting id-less or disabled bots.

        Raises:
            HTTPException: 400 if the id is empty; 403 if ``bot.enabled`` is falsy.
        """
        bot_id = str(bot.id or "").strip()
        if not bot_id:
            raise HTTPException(status_code=400, detail="Bot id is required")
        # Bots lacking an ``enabled`` attribute are treated as enabled.
        if not bool(getattr(bot, "enabled", True)):
            raise HTTPException(status_code=403, detail="Bot is disabled. Enable it first.")
        return bot_id

    def _client_for_bot(self, bot_id: str) -> EdgeClient:
        """Resolve the bot's provider target and return its edge client."""
        target = self._read_provider_target(bot_id)
        return self._client_for_target(target)

    def _client_for_target(self, target: ProviderTarget) -> EdgeClient:
        """Return the edge client for ``target``.

        Raises:
            HTTPException: 400 if ``target.transport_kind`` is not ``"edge"``.
        """
        if target.transport_kind != "edge":
            raise HTTPException(status_code=400, detail=f"edge runtime provider requires edge transport, got {target.transport_kind}")
        return self._resolve_edge_client(target)

    def _node_runtime_overrides(self, node_id: str, runtime_kind: str) -> Dict[str, str]:
        """Build runtime overrides from the node's metadata.

        ``workspace_root`` is included whenever it is non-empty. The
        ``native_*`` keys are included only when ``runtime_kind`` is
        ``"native"`` (case- and whitespace-insensitive). ``native_sandbox_mode``
        is omitted when it normalizes to ``"inherit"``.
        """
        metadata = dict(self._read_node_metadata(str(node_id or "").strip().lower()) or {})
        payload: Dict[str, str] = {}
        workspace_root = str(metadata.get("workspace_root") or "").strip()
        if workspace_root:
            payload["workspace_root"] = workspace_root
        if str(runtime_kind or "").strip().lower() != "native":
            return payload
        native_sandbox_mode = self._normalize_native_sandbox_mode(metadata.get("native_sandbox_mode"))
        if native_sandbox_mode != "inherit":
            payload["native_sandbox_mode"] = native_sandbox_mode
        native_command = str(metadata.get("native_command") or "").strip()
        native_workdir = str(metadata.get("native_workdir") or "").strip()
        if native_command:
            payload["native_command"] = native_command
        if native_workdir:
            payload["native_workdir"] = native_workdir
        return payload

    @staticmethod
    def _normalize_native_sandbox_mode(raw_value: Any) -> str:
        """Map a raw sandbox-mode value to one of ``"workspace"``, ``"full_access"`` or ``"inherit"``.

        Recognized aliases (case-insensitive, stripped):

        - ``"workspace"``: workspace / sandbox / strict.
        - ``"full_access"``: full_access / full-access / danger-full-access / escape.

        Anything else, including empty or ``None``, yields ``"inherit"``.
        """
        text = str(raw_value or "").strip().lower()
        if text in {"workspace", "sandbox", "strict"}:
            return "workspace"
        if text in {"full_access", "full-access", "danger-full-access", "escape"}:
            return "full_access"
        return "inherit"