dashboard-nanobot/backend/services/runtime_event_service.py

290 lines
12 KiB
Python

import asyncio
import json
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional

from fastapi import HTTPException, WebSocket
from sqlmodel import Session

from models.bot import BotInstance, BotMessage
class WSConnectionManager:
    """Registry of open WebSockets grouped by bot id, with JSON fan-out to each group."""

    def __init__(self) -> None:
        # bot_id -> sockets that have been accepted and not yet disconnected.
        self.connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, bot_id: str, websocket: WebSocket):
        """Complete the WebSocket handshake and register the socket under ``bot_id``."""
        await websocket.accept()
        if bot_id not in self.connections:
            self.connections[bot_id] = []
        self.connections[bot_id].append(websocket)

    def disconnect(self, bot_id: str, websocket: WebSocket):
        """Unregister ``websocket``; remove the bot's entry once it has no sockets left."""
        registered = self.connections.get(bot_id, [])
        if websocket in registered:
            registered.remove(websocket)
        if bot_id in self.connections and not registered:
            del self.connections[bot_id]

    async def broadcast(self, bot_id: str, data: Dict[str, Any]):
        """Send ``data`` as JSON to every socket of ``bot_id``, pruning sockets whose send fails."""
        # Walk a snapshot: a failed send calls disconnect(), which mutates the live list.
        for ws_conn in tuple(self.connections.get(bot_id, ())):
            try:
                await ws_conn.send_json(data)
            except Exception:
                # Best-effort delivery: a broken socket is dropped rather than reported.
                self.disconnect(bot_id, ws_conn)
class RuntimeEventService:
    """Persist bot runtime packets, keep related caches coherent, and relay packets to WebSockets."""

    # Packet types that persist_runtime_packet() acts on; every other type is ignored.
    _PERSISTED_PACKET_TYPES = frozenset({"AGENT_STATE", "ASSISTANT_MESSAGE", "USER_COMMAND", "BUS_EVENT"})
    # Packet types that may store a chat message, so cached message lists must be dropped.
    _MESSAGE_PACKET_TYPES = frozenset({"ASSISTANT_MESSAGE", "USER_COMMAND", "BUS_EVENT"})
    # Workspace path as reported by the bot (presumably its in-container mount point -- confirm);
    # media paths under it are stored relative to the workspace.
    _BOT_WORKSPACE_PREFIX = "/root/.nanobot/workspace/"
    # Runs of four or more newlines in last-action text collapse to three.
    _EXCESS_NEWLINES_RE = re.compile(r"\n{4,}")
    # Upper bound on the stored length of last-action text.
    _LAST_ACTION_MAX_CHARS = 16000
    # Largest accepted |timezone offset| in minutes; real offsets stay within about +/-14 hours.
    _MAX_TZ_OFFSET_MINUTES = 24 * 60

    def __init__(
        self,
        *,
        app: Any,
        engine: Any,
        cache: Any,
        logger: Any,
        publish_runtime_topic_packet: Callable[..., None],
        bind_usage_message: Callable[..., None],
        finalize_usage_from_packet: Callable[..., Any],
        workspace_root: Callable[[str], str],
        parse_message_media: Callable[[str, Optional[str]], List[str]],
    ) -> None:
        """Store collaborators and create the WebSocket connection manager.

        Args:
            app: object whose ``state.main_loop`` is the event loop broadcasts are scheduled on.
            engine: database engine handed to ``sqlmodel.Session``.
            cache: cache exposing ``delete(key)`` and ``delete_prefix(prefix)``.
            logger: logger used for persistence failures; also passed to the publish hook.
            publish_runtime_topic_packet: called after each packet that reaches commit.
            bind_usage_message: called as ``(session, bot_id, request_id, message_id)`` after a
                user message is stored.
            finalize_usage_from_packet: called as ``(session, bot_id, usage_dict)`` after an
                assistant message is stored.
            workspace_root: maps a bot id to its workspace directory on this host.
            parse_message_media: decodes a stored ``media_json`` value for a bot.
        """
        self._app = app
        self._engine = engine
        self._cache = cache
        self._logger = logger
        self._publish_runtime_topic_packet = publish_runtime_topic_packet
        self._bind_usage_message = bind_usage_message
        self._finalize_usage_from_packet = finalize_usage_from_packet
        self._workspace_root = workspace_root
        self._parse_message_media = parse_message_media
        self.manager = WSConnectionManager()

    # ------------------------------------------------------------------ cache keys

    @staticmethod
    def cache_key_bots_list(user_id: Optional[int] = None) -> str:
        """Return the cache key for a user's bot list; a missing user id maps to 0."""
        normalized_user_id = int(user_id or 0)
        return f"bots:list:user:{normalized_user_id}"

    @staticmethod
    def cache_key_bot_detail(bot_id: str) -> str:
        """Return the cache key for one bot's detail payload."""
        return f"bot:detail:{bot_id}"

    @staticmethod
    def cache_key_bot_messages(bot_id: str, limit: int) -> str:
        """Return the cache key for a bot's newest ``limit`` messages."""
        return f"bot:messages:v2:{bot_id}:limit:{limit}"

    @staticmethod
    def cache_key_bot_messages_page(bot_id: str, limit: int, before_id: Optional[int]) -> str:
        """Return the cache key for one page of a bot's messages.

        A positive integer ``before_id`` is the page cursor; anything else selects the newest page.
        """
        cursor = str(int(before_id)) if isinstance(before_id, int) and before_id > 0 else "latest"
        return f"bot:messages:page:v2:{bot_id}:before:{cursor}:limit:{limit}"

    @staticmethod
    def cache_key_images() -> str:
        """Return the cache key for the image list."""
        return "images:list"

    def invalidate_bot_detail_cache(self, bot_id: str) -> None:
        """Drop the bot's detail entry and every user's cached bot list."""
        self._cache.delete(self.cache_key_bot_detail(bot_id))
        # Any user's list may contain this bot, so all lists are cleared.
        self._cache.delete_prefix("bots:list:user:")

    def invalidate_bot_messages_cache(self, bot_id: str) -> None:
        """Drop every cached message list and message page for ``bot_id``.

        The prefixes mirror ``cache_key_bot_messages`` and ``cache_key_bot_messages_page``.
        Each ends with ``:`` so that clearing bot ``b1`` leaves bot ``b10`` untouched.
        """
        for prefix in (
            f"bot:messages:v2:{bot_id}:",
            f"bot:messages:page:v2:{bot_id}:",
            # Unversioned layout this method previously targeted; kept in case such keys exist.
            f"bot:messages:{bot_id}:",
        ):
            self._cache.delete_prefix(prefix)

    def invalidate_images_cache(self) -> None:
        """Drop the cached image list."""
        self._cache.delete(self.cache_key_images())

    # ------------------------------------------------------------ normalization

    @staticmethod
    def normalize_last_action_text(value: Any) -> str:
        """Normalize text for a bot's ``last_action`` field.

        Converts CRLF/CR to LF, strips surrounding whitespace, collapses runs of four or more
        newlines to three, and truncates to 16000 characters. Falsy input yields ``""``.
        """
        text = str(value or "").replace("\r\n", "\n").replace("\r", "\n").strip()
        if not text:
            return ""
        text = RuntimeEventService._EXCESS_NEWLINES_RE.sub("\n\n\n", text)
        return text[: RuntimeEventService._LAST_ACTION_MAX_CHARS]

    @staticmethod
    def normalize_packet_channel(packet: Dict[str, Any]) -> str:
        """Return the packet's lower-cased ``channel`` (or ``source``), folding dashboard aliases."""
        raw = str(packet.get("channel") or packet.get("source") or "").strip().lower()
        if raw in {"dashboard", "dashboard_channel", "dashboard-channel"}:
            return "dashboard"
        return raw

    def normalize_media_item(self, bot_id: str, value: Any) -> str:
        """Convert one media reference into a forward-slash path relative to the bot workspace.

        Paths under the bot-reported workspace prefix or under ``workspace_root(bot_id)`` are
        made relative to that root. Anything else just loses its leading slashes. Empty
        input yields ``""``.
        NOTE(review): ``..`` segments are not rejected here; code that resolves the result
        against the workspace should validate it.
        """
        raw = str(value or "").strip().replace("\\", "/")
        if not raw:
            return ""
        prefix = self._BOT_WORKSPACE_PREFIX
        if raw.startswith(prefix):
            return raw[len(prefix):].lstrip("/")
        root = self._workspace_root(bot_id)
        if os.path.isabs(raw):
            try:
                # Normalize both sides: a trailing slash or "." segment in the root would
                # otherwise never equal the commonpath result.
                norm_root = os.path.normpath(root)
                candidate = os.path.normpath(raw)
                if os.path.commonpath([norm_root, candidate]) == norm_root:
                    return os.path.relpath(candidate, norm_root).replace("\\", "/")
            except (TypeError, ValueError):
                # commonpath/relpath reject mixed absolute/relative paths or different drives;
                # fall back to the generic handling below.
                pass
        return raw.lstrip("/")

    def normalize_media_list(self, raw: Any, bot_id: str) -> List[str]:
        """Normalize each media entry of a list, dropping empty results; non-lists yield ``[]``."""
        if not isinstance(raw, list):
            return []
        normalized = (self.normalize_media_item(bot_id, value) for value in raw)
        return [item for item in normalized if item]

    def serialize_bot_message_row(self, bot_id: str, row: "BotMessage") -> Dict[str, Any]:
        """Convert a stored message row into a JSON-ready dict.

        A naive ``created_at`` is treated as UTC; ``ts`` is milliseconds since the epoch.
        """
        created_at = row.created_at
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        return {
            "id": row.id,
            "bot_id": row.bot_id,
            "role": row.role,
            "text": row.text,
            "media": self._parse_message_media(bot_id, getattr(row, "media_json", None)),
            "feedback": str(getattr(row, "feedback", "") or "").strip() or None,
            "ts": int(created_at.timestamp() * 1000),
        }

    @staticmethod
    def resolve_local_day_range(date_text: str, tz_offset_minutes: Optional[int]) -> tuple[datetime, datetime]:
        """Return the naive-UTC ``[start, end)`` range covering one local calendar day.

        Args:
            date_text: the local day as ``YYYY-MM-DD``.
            tz_offset_minutes: minutes added to local midnight to obtain UTC. This matches the
                sign of JavaScript's ``Date.getTimezoneOffset()`` (UTC+8 -> -480), which is
                presumably what the frontend sends -- confirm. ``None`` means UTC.

        Raises:
            HTTPException: 400 for an unparsable date, a non-integer or out-of-range offset,
                or a day whose range falls outside what ``datetime`` can represent.
        """
        try:
            local_day = datetime.strptime(str(date_text or "").strip(), "%Y-%m-%d")
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Invalid date, expected YYYY-MM-DD") from exc
        offset_minutes = 0
        if tz_offset_minutes is not None:
            try:
                offset_minutes = int(tz_offset_minutes)
            except (TypeError, ValueError, OverflowError) as exc:
                raise HTTPException(status_code=400, detail="Invalid timezone offset") from exc
        # Reject absurd offsets up front; otherwise they overflow timedelta arithmetic (a 500).
        if abs(offset_minutes) > RuntimeEventService._MAX_TZ_OFFSET_MINUTES:
            raise HTTPException(status_code=400, detail="Invalid timezone offset")
        try:
            utc_start = local_day + timedelta(minutes=offset_minutes)
            utc_end = utc_start + timedelta(days=1)
        except OverflowError as exc:
            # e.g. 9999-12-31: the day's end lies past datetime.max.
            raise HTTPException(status_code=400, detail="Date out of range") from exc
        return utc_start, utc_end

    # ------------------------------------------------------------- persistence

    def _insert_message(self, session: Any, bot_id: str, role: str, text: str, media_list: List[str]) -> Optional[int]:
        """Add a BotMessage row, flush so the database assigns its id, and return that id."""
        message_row = BotMessage(
            bot_id=bot_id,
            role=role,
            text=text,
            media_json=json.dumps(media_list, ensure_ascii=False) if media_list else None,
        )
        session.add(message_row)
        session.flush()
        return message_row.id

    def _record_assistant_reply(
        self,
        session: Any,
        bot: Any,
        bot_id: str,
        text_msg: str,
        media_list: List[str],
        usage_fields: Dict[str, Any],
    ) -> Optional[int]:
        """Store an assistant reply, update ``last_action``, and finalize usage for it."""
        if text_msg:
            bot.last_action = self.normalize_last_action_text(text_msg)
        message_id = self._insert_message(session, bot_id, "assistant", text_msg, media_list)
        self._finalize_usage_from_packet(session, bot_id, {**usage_fields, "message_id": message_id})
        return message_id

    def _apply_agent_state(self, bot: Any, packet: Dict[str, Any]) -> None:
        """Copy ``state`` and action text from an AGENT_STATE payload onto the bot row."""
        payload = packet.get("payload")
        if not isinstance(payload, dict):
            # A missing or malformed payload carries no state; ignore it instead of crashing.
            payload = {}
        state = str(payload.get("state") or "").strip()
        action = self.normalize_last_action_text(payload.get("action_msg") or payload.get("msg") or "")
        if state:
            bot.current_state = state
        if action:
            bot.last_action = action

    def persist_runtime_packet(self, bot_id: str, packet: Dict[str, Any]) -> Optional[int]:
        """Apply a dashboard-channel runtime packet to the database.

        Only AGENT_STATE, ASSISTANT_MESSAGE, USER_COMMAND and BUS_EVENT packets are handled.
        The channel must normalize to ``dashboard`` and the bot must exist. Chat text/media
        is stored as a message row and the bot's state fields are updated.

        After commit, the packet is passed to ``publish_runtime_topic_packet``. The packet
        is mutated in place: ``packet["message_id"]`` is set when a message was stored.
        The affected caches are then invalidated.

        Returns:
            The stored message's id, or ``None`` when no message was stored.
        """
        packet_type = str(packet.get("type", "")).upper()
        if packet_type not in self._PERSISTED_PACKET_TYPES:
            return None
        source_channel = self.normalize_packet_channel(packet)
        if source_channel != "dashboard":
            return None
        persisted_message_id: Optional[int] = None
        with Session(self._engine) as session:
            bot = session.get(BotInstance, bot_id)
            if not bot:
                return None
            if packet_type == "AGENT_STATE":
                self._apply_agent_state(bot, packet)
            elif packet_type == "ASSISTANT_MESSAGE":
                bot.current_state = "IDLE"
                text_msg = str(packet.get("text") or "").strip()
                media_list = self.normalize_media_list(packet.get("media"), bot_id)
                if text_msg or media_list:
                    persisted_message_id = self._record_assistant_reply(
                        session, bot, bot_id, text_msg, media_list, packet
                    )
            elif packet_type == "USER_COMMAND":
                text_msg = str(packet.get("text") or "").strip()
                media_list = self.normalize_media_list(packet.get("media"), bot_id)
                if text_msg or media_list:
                    persisted_message_id = self._insert_message(session, bot_id, "user", text_msg, media_list)
                    self._bind_usage_message(
                        session,
                        bot_id,
                        str(packet.get("request_id") or "").strip(),
                        persisted_message_id,
                    )
            elif packet_type == "BUS_EVENT" and not bool(packet.get("is_progress")):
                # Progress events are transient; only final bus output becomes a message.
                text_msg = str(packet.get("content") or packet.get("text") or "").strip()
                media_list = self.normalize_media_list(packet.get("media"), bot_id)
                if text_msg or media_list:
                    bot.current_state = "IDLE"
                    persisted_message_id = self._record_assistant_reply(
                        session,
                        bot,
                        bot_id,
                        text_msg,
                        media_list,
                        {
                            "text": text_msg,
                            "usage": packet.get("usage"),
                            "request_id": packet.get("request_id"),
                            "provider": packet.get("provider"),
                            "model": packet.get("model"),
                        },
                    )
            # Same naive-UTC value utcnow() produced, without the 3.12 deprecation warning.
            bot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            session.add(bot)
            session.commit()
        self._publish_runtime_topic_packet(
            self._engine,
            bot_id,
            packet,
            source_channel,
            persisted_message_id,
            self._logger,
        )
        if persisted_message_id is not None:
            packet["message_id"] = persisted_message_id
        if packet_type in self._MESSAGE_PACKET_TYPES:
            self.invalidate_bot_messages_cache(bot_id)
        self.invalidate_bot_detail_cache(bot_id)
        return persisted_message_id

    # ---------------------------------------------------------------- delivery

    def broadcast_runtime_packet(self, bot_id: str, packet: Dict[str, Any], loop: Any) -> None:
        """Schedule a broadcast of ``packet`` to the bot's WebSockets on ``loop`` (thread-safe).

        Does nothing when ``loop`` is missing or not running, so no coroutine is created
        that would never be awaited.
        """
        if loop is None or not loop.is_running():
            return
        asyncio.run_coroutine_threadsafe(self.manager.broadcast(bot_id, packet), loop)

    def docker_callback(self, bot_id: str, packet: Dict[str, Any]):
        """Handle one runtime packet from a bot: persist it, then broadcast it to WebSockets.

        A persistence failure is logged and does not prevent the broadcast, so live clients
        still receive the event.
        """
        try:
            self.persist_runtime_packet(bot_id, packet)
        except Exception:
            # NOTE(review): presumably invoked from a background reader thread; letting the
            # error escape could stop event delivery for this bot -- confirm against caller.
            self._logger.exception("Failed to persist runtime packet for bot %s", bot_id)
        loop = getattr(self._app.state, "main_loop", None)
        self.broadcast_runtime_packet(bot_id, packet, loop)