290 lines
12 KiB
Python
290 lines
12 KiB
Python
import asyncio
import json
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional

from fastapi import HTTPException, WebSocket
from sqlmodel import Session

from models.bot import BotInstance, BotMessage
|
|
|
|
|
|
class WSConnectionManager:
    """Registry of open WebSocket connections, grouped by bot id, with JSON fan-out."""

    def __init__(self) -> None:
        # bot_id -> sockets currently registered for that bot
        self.connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, bot_id: str, websocket: WebSocket):
        """Complete the handshake for *websocket*, then register it under *bot_id*.

        The socket is only registered once ``accept()`` has returned successfully.
        """
        await websocket.accept()
        if bot_id not in self.connections:
            self.connections[bot_id] = []
        self.connections[bot_id].append(websocket)

    def disconnect(self, bot_id: str, websocket: WebSocket):
        """Unregister *websocket*; remove the bot's entry once it has no sockets left.

        Unknown bot ids and sockets that were never registered are ignored.
        """
        if bot_id not in self.connections:
            return
        registered = self.connections[bot_id]
        if websocket in registered:
            registered.remove(websocket)
        if not registered:
            del self.connections[bot_id]

    async def broadcast(self, bot_id: str, data: Dict[str, Any]):
        """Send *data* as JSON to every socket of *bot_id*.

        Any socket whose send raises is dropped via ``disconnect``; delivery to the
        remaining sockets continues.
        """
        # Walk a snapshot, because failing sockets are removed from the live list.
        for peer in tuple(self.connections.get(bot_id, ())):
            try:
                await peer.send_json(data)
            except Exception:
                self.disconnect(bot_id, peer)
|
|
|
|
|
|
class RuntimeEventService:
    """Persist dashboard runtime packets for bots and relay them to WebSocket clients.

    Every collaborator is injected through the keyword-only constructor. The service
    itself only owns the ``WSConnectionManager`` it creates for live fan-out.
    """

    # Packet types handled by persist_runtime_packet(); any other type is ignored.
    _PERSISTED_PACKET_TYPES = frozenset({"AGENT_STATE", "ASSISTANT_MESSAGE", "USER_COMMAND", "BUS_EVENT"})
    # Packet types that can add a chat message, so the message caches must be dropped.
    _MESSAGE_PACKET_TYPES = frozenset({"ASSISTANT_MESSAGE", "USER_COMMAND", "BUS_EVENT"})
    # Path prefix stripped from media references (presumably the in-container workspace mount).
    _CONTAINER_WORKSPACE_PREFIX = "/root/.nanobot/workspace/"

    def __init__(
        self,
        *,
        app: Any,
        engine: Any,
        cache: Any,
        logger: Any,
        publish_runtime_topic_packet: Callable[..., None],
        bind_usage_message: Callable[..., None],
        finalize_usage_from_packet: Callable[..., Any],
        workspace_root: Callable[[str], str],
        parse_message_media: Callable[[str, Optional[str]], List[str]],
    ) -> None:
        """Store collaborators.

        Args:
            app: object whose ``state.main_loop`` is read by ``docker_callback``.
            engine: database engine used to open ``Session`` objects; also forwarded
                to ``publish_runtime_topic_packet``.
            cache: object providing ``delete(key)`` and ``delete_prefix(prefix)``.
            logger: forwarded to ``publish_runtime_topic_packet``.
            publish_runtime_topic_packet: called as
                ``(engine, bot_id, packet, source_channel, message_id, logger)``.
            bind_usage_message: called as ``(session, bot_id, request_id, message_id)``.
            finalize_usage_from_packet: called as ``(session, bot_id, packet_dict)``.
            workspace_root: maps a bot id to its workspace directory path.
            parse_message_media: maps ``(bot_id, media_json)`` to a list of media paths.
        """
        self._app = app
        self._engine = engine
        self._cache = cache
        self._logger = logger
        self._publish_runtime_topic_packet = publish_runtime_topic_packet
        self._bind_usage_message = bind_usage_message
        self._finalize_usage_from_packet = finalize_usage_from_packet
        self._workspace_root = workspace_root
        self._parse_message_media = parse_message_media
        self.manager = WSConnectionManager()

    # ------------------------------------------------------------------ cache keys

    @staticmethod
    def cache_key_bots_list(user_id: Optional[int] = None) -> str:
        """Cache key for a user's bot list; a missing user id maps to user ``0``."""
        normalized_user_id = int(user_id or 0)
        return f"bots:list:user:{normalized_user_id}"

    @staticmethod
    def cache_key_bot_detail(bot_id: str) -> str:
        """Cache key for a single bot's detail payload."""
        return f"bot:detail:{bot_id}"

    @staticmethod
    def cache_key_bot_messages(bot_id: str, limit: int) -> str:
        """Cache key for the latest *limit* messages of a bot."""
        return f"bot:messages:v2:{bot_id}:limit:{limit}"

    @staticmethod
    def cache_key_bot_messages_page(bot_id: str, limit: int, before_id: Optional[int]) -> str:
        """Cache key for a page of messages older than *before_id*.

        A missing or non-positive *before_id* is keyed as ``"latest"``.
        """
        cursor = str(int(before_id)) if isinstance(before_id, int) and before_id > 0 else "latest"
        return f"bot:messages:page:v2:{bot_id}:before:{cursor}:limit:{limit}"

    @staticmethod
    def cache_key_images() -> str:
        """Cache key for the image list."""
        return "images:list"

    def invalidate_bot_detail_cache(self, bot_id: str) -> None:
        """Drop the bot's detail entry and every cached per-user bot list."""
        self._cache.delete(self.cache_key_bot_detail(bot_id))
        self._cache.delete_prefix("bots:list:user:")

    def invalidate_bot_messages_cache(self, bot_id: str) -> None:
        """Drop every cached message list and message page for *bot_id*.

        These prefixes must stay in sync with ``cache_key_bot_messages`` and
        ``cache_key_bot_messages_page``. The trailing ``:`` keeps ``bot1`` from
        also matching ``bot10``. The unversioned prefix is cleared as well, in case
        entries written under it still exist.
        """
        self._cache.delete_prefix(f"bot:messages:v2:{bot_id}:")
        self._cache.delete_prefix(f"bot:messages:page:v2:{bot_id}:")
        self._cache.delete_prefix(f"bot:messages:{bot_id}:")

    def invalidate_images_cache(self) -> None:
        """Drop the cached image list."""
        self._cache.delete(self.cache_key_images())

    # --------------------------------------------------------------- normalizers

    @staticmethod
    def normalize_last_action_text(value: Any) -> str:
        """Normalize free text for ``last_action``.

        Newlines are unified to ``\\n`` and the text is trimmed. Runs of four or more
        newlines collapse to three. The result is capped at 16000 characters.
        ``None``/empty input yields ``""``.
        """
        text = str(value or "").replace("\r\n", "\n").replace("\r", "\n").strip()
        if not text:
            return ""
        text = re.sub(r"\n{4,}", "\n\n\n", text)
        return text[:16000]

    @staticmethod
    def normalize_packet_channel(packet: Dict[str, Any]) -> str:
        """Return the packet's lower-cased channel (falling back to ``source``).

        All dashboard aliases are folded into ``"dashboard"``.
        """
        raw = str(packet.get("channel") or packet.get("source") or "").strip().lower()
        if raw in {"dashboard", "dashboard_channel", "dashboard-channel"}:
            return "dashboard"
        return raw

    def normalize_media_item(self, bot_id: str, value: Any) -> str:
        """Convert one media reference into a forward-slash, workspace-relative path.

        - Paths under ``_CONTAINER_WORKSPACE_PREFIX`` lose that prefix.
        - Absolute paths inside ``workspace_root(bot_id)`` become relative to it.
        - Anything else is returned with leading slashes stripped.

        Empty input yields ``""``.
        """
        raw = str(value or "").strip().replace("\\", "/")
        if not raw:
            return ""
        prefix = self._CONTAINER_WORKSPACE_PREFIX
        if raw.startswith(prefix):
            return raw[len(prefix):].lstrip("/")
        root = self._workspace_root(bot_id)
        if os.path.isabs(raw):
            try:
                # Normalize both sides before comparing. commonpath() drops trailing
                # separators, so a root like "/ws/" never equalled its result. It
                # also does not collapse "..", so "<root>/../x" looked like it was
                # inside the root.
                root_norm = os.path.normpath(root)
                candidate = os.path.normpath(raw)
                if os.path.commonpath([root_norm, candidate]) == root_norm:
                    return os.path.relpath(candidate, root_norm).replace("\\", "/")
            except (TypeError, ValueError):
                # Non-str root, mixed absolute/relative paths or different drives:
                # use the generic fallback below.
                pass
        return raw.lstrip("/")

    def normalize_media_list(self, raw: Any, bot_id: str) -> List[str]:
        """Normalize each entry of *raw* and drop the empty ones.

        Non-list input yields ``[]``.
        """
        if not isinstance(raw, list):
            return []
        normalized = (self.normalize_media_item(bot_id, value) for value in raw)
        return [item for item in normalized if item]

    def serialize_bot_message_row(self, bot_id: str, row: "BotMessage") -> Dict[str, Any]:
        """Convert a stored message row into the JSON-ready dict sent to clients.

        A naive ``created_at`` is treated as UTC. ``ts`` is in epoch milliseconds.
        Blank feedback is reported as ``None``.
        """
        created_at = row.created_at
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=timezone.utc)
        return {
            "id": row.id,
            "bot_id": row.bot_id,
            "role": row.role,
            "text": row.text,
            "media": self._parse_message_media(bot_id, getattr(row, "media_json", None)),
            "feedback": str(getattr(row, "feedback", "") or "").strip() or None,
            "ts": int(created_at.timestamp() * 1000),
        }

    @staticmethod
    def resolve_local_day_range(date_text: str, tz_offset_minutes: Optional[int]) -> tuple[datetime, datetime]:
        """Return the naive UTC ``[start, end)`` range covering one local calendar day.

        The offset is *added* to local midnight to get UTC. This matches the sign
        convention of JavaScript's ``Date.getTimezoneOffset()`` (UTC minus local,
        e.g. ``-480`` for UTC+8). NOTE(review): confirm callers send that convention.

        Raises:
            HTTPException: 400 for a malformed date or a non-integer offset.
        """
        try:
            local_day = datetime.strptime(str(date_text or "").strip(), "%Y-%m-%d")
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Invalid date, expected YYYY-MM-DD") from exc

        offset_minutes = 0
        if tz_offset_minutes is not None:
            try:
                offset_minutes = int(tz_offset_minutes)
            except (TypeError, ValueError) as exc:
                raise HTTPException(status_code=400, detail="Invalid timezone offset") from exc

        utc_start = local_day + timedelta(minutes=offset_minutes)
        utc_end = utc_start + timedelta(days=1)
        return utc_start, utc_end

    # --------------------------------------------------------------- persistence

    def _add_message_row(self, session: Any, bot_id: str, role: str, text: str, media: List[str]) -> Optional[int]:
        """Insert a ``BotMessage`` and flush so its primary key is assigned; return that id."""
        row = BotMessage(
            bot_id=bot_id,
            role=role,
            text=text,
            media_json=json.dumps(media, ensure_ascii=False) if media else None,
        )
        session.add(row)
        session.flush()
        return row.id

    def _apply_agent_state(self, bot: Any, packet: Dict[str, Any]) -> None:
        """Copy a non-empty state and action text from an AGENT_STATE payload onto *bot*."""
        payload = packet.get("payload")
        if not isinstance(payload, dict):
            # A missing or malformed payload carries no state; skip it instead of crashing.
            payload = {}
        state = str(payload.get("state") or "").strip()
        action = self.normalize_last_action_text(payload.get("action_msg") or payload.get("msg") or "")
        if state:
            bot.current_state = state
        if action:
            bot.last_action = action

    def _persist_assistant_message(self, session: Any, bot: Any, bot_id: str, packet: Dict[str, Any]) -> Optional[int]:
        """Mark *bot* idle and store a non-empty assistant reply; return the message id."""
        bot.current_state = "IDLE"
        text_msg = str(packet.get("text") or "").strip()
        media_list = self.normalize_media_list(packet.get("media"), bot_id)
        if not (text_msg or media_list):
            return None
        if text_msg:
            bot.last_action = self.normalize_last_action_text(text_msg)
        message_id = self._add_message_row(session, bot_id, "assistant", text_msg, media_list)
        # NOTE(review): usage is finalized only when a message row was written;
        # confirm this is the intended handling of usage-only packets.
        self._finalize_usage_from_packet(session, bot_id, {**packet, "message_id": message_id})
        return message_id

    def _persist_user_command(self, session: Any, bot_id: str, packet: Dict[str, Any]) -> Optional[int]:
        """Store a non-empty user command and bind it to its request; return the message id."""
        text_msg = str(packet.get("text") or "").strip()
        media_list = self.normalize_media_list(packet.get("media"), bot_id)
        if not (text_msg or media_list):
            return None
        message_id = self._add_message_row(session, bot_id, "user", text_msg, media_list)
        self._bind_usage_message(session, bot_id, str(packet.get("request_id") or "").strip(), message_id)
        return message_id

    def _persist_bus_event(self, session: Any, bot: Any, bot_id: str, packet: Dict[str, Any]) -> Optional[int]:
        """Store a final (non-progress) bus event as an assistant message; return its id."""
        if bool(packet.get("is_progress")):
            return None
        text_msg = str(packet.get("content") or packet.get("text") or "").strip()
        media_list = self.normalize_media_list(packet.get("media"), bot_id)
        if not (text_msg or media_list):
            return None
        bot.current_state = "IDLE"
        if text_msg:
            bot.last_action = self.normalize_last_action_text(text_msg)
        message_id = self._add_message_row(session, bot_id, "assistant", text_msg, media_list)
        self._finalize_usage_from_packet(
            session,
            bot_id,
            {
                "text": text_msg,
                "usage": packet.get("usage"),
                "request_id": packet.get("request_id"),
                "provider": packet.get("provider"),
                "model": packet.get("model"),
                "message_id": message_id,
            },
        )
        return message_id

    def persist_runtime_packet(self, bot_id: str, packet: Dict[str, Any]) -> Optional[int]:
        """Apply a dashboard runtime packet to the database and return the new message id.

        Only packets whose type is in ``_PERSISTED_PACKET_TYPES`` and whose channel
        normalizes to ``"dashboard"`` are handled. Anything else returns ``None``
        without opening a session. ``None`` is also returned when the bot does not
        exist or when the packet did not produce a message row.

        When the bot exists, the method:

        - updates the bot row (state / last action / ``updated_at``) and commits;
        - publishes the packet via ``publish_runtime_topic_packet``;
        - writes ``message_id`` into *packet* when a row was stored;
        - invalidates the bot-detail cache and, for message-bearing packet types,
          the message caches.
        """
        packet_type = str(packet.get("type", "")).upper()
        if packet_type not in self._PERSISTED_PACKET_TYPES:
            return None
        source_channel = self.normalize_packet_channel(packet)
        if source_channel != "dashboard":
            return None

        persisted_message_id: Optional[int] = None
        with Session(self._engine) as session:
            bot = session.get(BotInstance, bot_id)
            if not bot:
                return None
            if packet_type == "AGENT_STATE":
                self._apply_agent_state(bot, packet)
            elif packet_type == "ASSISTANT_MESSAGE":
                persisted_message_id = self._persist_assistant_message(session, bot, bot_id, packet)
            elif packet_type == "USER_COMMAND":
                persisted_message_id = self._persist_user_command(session, bot_id, packet)
            else:  # BUS_EVENT
                persisted_message_id = self._persist_bus_event(session, bot, bot_id, packet)

            # Naive UTC, same value as the deprecated datetime.utcnow().
            bot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            session.add(bot)
            session.commit()

        self._publish_runtime_topic_packet(
            self._engine,
            bot_id,
            packet,
            source_channel,
            persisted_message_id,
            self._logger,
        )

        if persisted_message_id:
            packet["message_id"] = persisted_message_id
        if packet_type in self._MESSAGE_PACKET_TYPES:
            self.invalidate_bot_messages_cache(bot_id)
        self.invalidate_bot_detail_cache(bot_id)
        return persisted_message_id

    # ------------------------------------------------------------------ fan-out

    def broadcast_runtime_packet(self, bot_id: str, packet: Dict[str, Any], loop: Any) -> None:
        """Schedule a broadcast of *packet* to *bot_id*'s sockets on *loop*, from any thread.

        Raises whatever ``asyncio.run_coroutine_threadsafe`` raises (e.g. a
        ``RuntimeError`` when *loop* is closed).
        """
        coro = self.manager.broadcast(bot_id, packet)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception:
            # Scheduling failed: close the coroutine so it is not reported as
            # "never awaited", then surface the original error.
            coro.close()
            raise

    def docker_callback(self, bot_id: str, packet: Dict[str, Any]):
        """Persist *packet*, then broadcast it if the app's main loop is running.

        The packet is broadcast after persistence, so any ``message_id`` that
        ``persist_runtime_packet`` writes into it is included.
        """
        self.persist_runtime_packet(bot_id, packet)
        loop = getattr(self._app.state, "main_loop", None)
        if loop is None or not loop.is_running():
            return
        self.broadcast_runtime_packet(bot_id, packet, loop)
|