334 lines
14 KiB
Python
334 lines
14 KiB
Python
import asyncio
|
||
import os
|
||
import threading
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Callable, Dict, List, Optional
|
||
|
||
from fastapi import HTTPException
|
||
from sqlmodel import Session
|
||
|
||
from models.bot import BotInstance
|
||
from providers.runtime.base import RuntimeProvider
|
||
|
||
|
||
class BotCommandService:
    """Submit dashboard commands to a bot runtime and relay edge monitor output.

    All collaborators (runtime snapshot reader, usage/activity recorders, packet
    persistence, event-loop lookup, broadcaster, ...) are injected as callables,
    so this service only orchestrates them.
    """

    # Lower-cased class name of the runtime provider whose monitor stream is
    # polled after a command; every other provider skips monitor syncing.
    _EDGE_PROVIDER_CLASS_NAME = "edgeruntimeprovider"
    # How long one monitor-sync thread keeps polling after a command (seconds).
    _MONITOR_SYNC_TIMEOUT_S = 18.0
    # Pause between monitor polls when there is nothing more to page through.
    _MONITOR_POLL_INTERVAL_S = 0.5
    # Page size used while relaying monitor packets.
    _MONITOR_PAGE_LIMIT = 200
    # Page size and page cap used while locating the newest monitor seq.
    _MONITOR_BASELINE_PAGE_LIMIT = 1000
    _MONITOR_BASELINE_MAX_PAGES = 100
    # Number of recent seqs per bot remembered for de-duplication; seqs more
    # than this far below the newest relayed seq are treated as already relayed.
    _MONITOR_SEEN_WINDOW = 4096
    # Fallback text used when the provider reports a failure without detail.
    _DELIVERY_FAILED_MESSAGE = "command delivery failed"

    def __init__(
        self,
        *,
        read_runtime_snapshot: Callable[[BotInstance], Dict[str, Any]],
        normalize_media_list: Callable[[Any, str], List[str]],
        resolve_workspace_path: Callable[[str, Optional[str]], tuple[str, str]],
        is_visual_attachment_path: Callable[[str], bool],
        is_video_attachment_path: Callable[[str], bool],
        create_usage_request: Callable[..., str],
        record_activity_event: Callable[..., None],
        fail_latest_usage: Callable[[Session, str, str], None],
        persist_runtime_packet: Callable[[str, Dict[str, Any]], Optional[int]],
        get_main_loop: Callable[[Any], Any],
        broadcast_packet: Callable[[str, Dict[str, Any], Any], None],
    ) -> None:
        self._read_runtime_snapshot = read_runtime_snapshot
        self._normalize_media_list = normalize_media_list
        self._resolve_workspace_path = resolve_workspace_path
        self._is_visual_attachment_path = is_visual_attachment_path
        self._is_video_attachment_path = is_video_attachment_path
        self._create_usage_request = create_usage_request
        self._record_activity_event = record_activity_event
        self._fail_latest_usage = fail_latest_usage
        self._persist_runtime_packet = persist_runtime_packet
        self._get_main_loop = get_main_loop
        self._broadcast_packet = broadcast_packet
        # Live sync threads keyed by (bot_id, request key); guarded by _monitor_sync_lock.
        self._monitor_sync_threads: Dict[tuple[str, str], threading.Thread] = {}
        self._monitor_sync_lock = threading.Lock()
        # Per-bot newest relayed seq and recently relayed seqs; guarded by _monitor_sync_seq_lock.
        self._monitor_sync_seq_lock = threading.Lock()
        self._monitor_sync_last_seq: Dict[str, int] = {}
        self._monitor_sync_seen_seqs: Dict[str, set[int]] = {}

    def execute(
        self,
        *,
        session: Session,
        bot_id: str,
        bot: BotInstance,
        payload: Any,
        runtime_provider: RuntimeProvider,
        app_state: Any,
    ) -> Dict[str, Any]:
        """Validate, record and deliver a dashboard command, then start monitor sync.

        Steps: validate command/attachments, create a usage request and a
        ``command_submitted`` activity event, persist and broadcast a
        ``USER_COMMAND`` packet, deliver the command through ``runtime_provider``
        and (edge providers only) start a background monitor-sync thread.

        Returns:
            ``{"success": True}`` once the command has been delivered.

        Raises:
            HTTPException: 400 when neither command text nor attachments are
                given, or a non-edge attachment is missing from the workspace;
                502 when the provider reports a delivery failure.
        """
        runtime_snapshot = self._read_runtime_snapshot(bot)
        attachments = self._normalize_media_list(getattr(payload, "attachments", None), bot_id)
        command = str(getattr(payload, "command", None) or "").strip()
        if not command and not attachments:
            raise HTTPException(status_code=400, detail="Command or attachments is required")

        checked_attachments = self._check_attachments(bot_id, bot, attachments)
        delivery_media = [f"/root/.nanobot/workspace/{p.lstrip('/')}" for p in checked_attachments]

        display_command = command if command else "[attachment message]"
        delivery_command = self._build_delivery_command(command=command, checked_attachments=checked_attachments)

        request_id = self._create_usage_request(
            session,
            bot_id,
            display_command,
            attachments=checked_attachments,
            channel="dashboard",
            metadata={"attachment_count": len(checked_attachments)},
            provider=str(runtime_snapshot.get("llm_provider") or "").strip() or None,
            model=str(runtime_snapshot.get("llm_model") or "").strip() or None,
        )
        self._record_activity_event(
            session,
            bot_id,
            "command_submitted",
            request_id=request_id,
            channel="dashboard",
            detail="command submitted",
            metadata={
                "attachment_count": len(checked_attachments),
                "has_text": bool(command),
            },
        )
        session.commit()

        # display_command always has a value, so the user's message is always echoed.
        outbound_user_packet: Dict[str, Any] = {
            "type": "USER_COMMAND",
            "channel": "dashboard",
            "text": display_command,
            "media": checked_attachments,
            "request_id": request_id,
        }
        self._persist_runtime_packet(bot_id, outbound_user_packet)

        loop = self._get_main_loop(app_state)
        if loop and loop.is_running():
            self._broadcast_packet(bot_id, outbound_user_packet, loop)

        # Capture the monitor baseline BEFORE delivering. If it were captured
        # afterwards, packets the bot emits while the baseline is being fetched
        # would get seq <= baseline and never be relayed. Only the edge provider
        # is synced, so other providers skip this round-trip.
        baseline_seq = (
            self._resolve_monitor_baseline_seq(runtime_provider, bot_id)
            if self._is_edge_runtime_provider(runtime_provider)
            else 0
        )

        detail = runtime_provider.deliver_command(bot_id=bot_id, command=delivery_command, media=delivery_media)
        if detail is not None:
            self._record_delivery_failure(
                session=session,
                bot_id=bot_id,
                request_id=request_id,
                detail=detail,
                loop=loop,
            )
            raise HTTPException(
                status_code=502,
                detail=f"Failed to deliver command to bot dashboard channel{': ' + detail if detail else ''}",
            )

        self._maybe_sync_edge_monitor_packets(
            runtime_provider=runtime_provider,
            bot_id=bot_id,
            request_id=request_id,
            after_seq=baseline_seq,
            app_state=app_state,
        )
        return {"success": True}

    def _check_attachments(self, bot_id: str, bot: BotInstance, attachments: List[str]) -> List[str]:
        """Return a copy of ``attachments`` after checking they exist locally.

        Bots whose ``transport_kind`` is ``"edge"`` are not checked against the
        local workspace (presumably their files live on the remote runtime).

        Raises:
            HTTPException: 400 for the first attachment whose resolved workspace
                path is not an existing file.
        """
        transport_kind = str(getattr(bot, "transport_kind", "") or "").strip().lower()
        if transport_kind != "edge":
            for rel in attachments:
                _, target = self._resolve_workspace_path(bot_id, rel)
                if not os.path.isfile(target):
                    raise HTTPException(status_code=400, detail=f"attachment not found: {rel}")
        return list(attachments)

    def _record_delivery_failure(
        self,
        *,
        session: Session,
        bot_id: str,
        request_id: str,
        detail: str,
        loop: Any,
    ) -> None:
        """Mark the usage request failed, log a ``command_failed`` event and broadcast an ERROR state."""
        message = detail or self._DELIVERY_FAILED_MESSAGE
        self._fail_latest_usage(session, bot_id, message)
        self._record_activity_event(
            session,
            bot_id,
            "command_failed",
            request_id=request_id,
            channel="dashboard",
            detail=message[:400],
        )
        session.commit()
        if loop and loop.is_running():
            self._broadcast_packet(
                bot_id,
                {
                    "type": "AGENT_STATE",
                    "channel": "dashboard",
                    "payload": {
                        "state": "ERROR",
                        "action_msg": message,
                    },
                },
                loop,
            )

    @classmethod
    def _is_edge_runtime_provider(cls, runtime_provider: Any) -> bool:
        """Return True if ``runtime_provider``'s class name matches the edge provider."""
        return runtime_provider.__class__.__name__.strip().lower() == cls._EDGE_PROVIDER_CLASS_NAME

    @staticmethod
    def _monitor_thread_key(bot_id: Any, request_id: Any, after_seq: Any) -> tuple[str, str]:
        """Return the ``(bot_key, request_key)`` used to de-duplicate sync threads."""
        bot_key = str(bot_id or "").strip()
        request_key = str(request_id or "").strip() or f"seq:{int(after_seq or 0)}"
        return bot_key, request_key

    def _maybe_sync_edge_monitor_packets(
        self,
        *,
        runtime_provider: RuntimeProvider,
        bot_id: str,
        request_id: str,
        after_seq: int,
        app_state: Any,
    ) -> None:
        """Start a daemon monitor-sync thread for an edge provider unless one is already running.

        Non-edge providers and blank bot ids are ignored. At most one live thread
        exists per ``(bot_id, request_id)`` key; if ``request_id`` is blank, the
        key falls back to ``seq:<after_seq>``.
        """
        if not self._is_edge_runtime_provider(runtime_provider):
            return
        thread_key = self._monitor_thread_key(bot_id, request_id, after_seq)
        bot_key = thread_key[0]
        if not bot_key:
            return
        with self._monitor_sync_lock:
            existing = self._monitor_sync_threads.get(thread_key)
            if existing is not None and existing.is_alive():
                return
            thread = threading.Thread(
                target=self._sync_edge_monitor_packets,
                args=(runtime_provider, bot_key, request_id, after_seq, app_state),
                name=f"edge-monitor-sync:{bot_key}",
                daemon=True,
            )
            self._monitor_sync_threads[thread_key] = thread
            thread.start()

    def sync_edge_monitor_packets(
        self,
        *,
        runtime_provider: RuntimeProvider,
        bot_id: str,
        request_id: str,
        app_state: Any,
    ) -> None:
        """Start (if needed) a monitor sync for ``request_id`` that reads the stream from seq 0."""
        self._maybe_sync_edge_monitor_packets(
            runtime_provider=runtime_provider,
            bot_id=bot_id,
            request_id=request_id,
            after_seq=0,
            app_state=app_state,
        )

    @staticmethod
    def _row_seq(row: Any) -> int:
        """Return a monitor row's integer ``seq``, or 0 when it is missing or invalid."""
        getter = getattr(row, "get", None)
        if not callable(getter):
            return 0
        try:
            return int(getter("seq") or 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    @staticmethod
    def _row_packet(row: Any) -> Dict[str, Any]:
        """Return a shallow copy of a monitor row's ``packet``, or ``{}`` when unusable."""
        getter = getattr(row, "get", None)
        if not callable(getter):
            return {}
        try:
            return dict(getter("packet") or {})
        except (TypeError, ValueError):
            return {}

    @staticmethod
    def _should_relay_packet(packet: Dict[str, Any], request_id_norm: str) -> bool:
        """Decide whether a monitor packet belongs to this sync.

        ``USER_COMMAND`` packets are never relayed (the dashboard already
        persisted and broadcast its own copy). ``ASSISTANT_MESSAGE`` and
        ``BUS_EVENT`` packets tagged with a different request id are left for
        that request's sync.
        """
        packet_type = str(packet.get("type") or "").strip().upper()
        if packet_type == "USER_COMMAND":
            return False
        if packet_type in {"ASSISTANT_MESSAGE", "BUS_EVENT"} and request_id_norm:
            packet_request_id = str(packet.get("request_id") or "").strip()
            if packet_request_id and packet_request_id != request_id_norm:
                return False
        return True

    def _sync_edge_monitor_packets(
        self,
        runtime_provider: RuntimeProvider,
        bot_id: str,
        request_id: str,
        after_seq: int,
        app_state: Any,
    ) -> None:
        """Thread body: poll the provider's monitor stream and relay new packets.

        Polls ``get_monitor_packets`` with an ``after_seq`` cursor for
        ``_MONITOR_SYNC_TIMEOUT_S`` seconds. Each packet accepted by
        ``_should_relay_packet`` and not relayed before (``_mark_monitor_seq``)
        is persisted and broadcast. Provider errors are treated as transient.
        On exit, the thread removes its own registry entry.
        """
        loop = self._get_main_loop(app_state)
        cursor = max(0, int(after_seq or 0))
        request_id_norm = str(request_id or "").strip()
        deadline = time.monotonic() + self._MONITOR_SYNC_TIMEOUT_S
        try:
            while time.monotonic() < deadline:
                try:
                    rows = list(
                        runtime_provider.get_monitor_packets(
                            bot_id=bot_id, after_seq=cursor, limit=self._MONITOR_PAGE_LIMIT
                        )
                        or []
                    )
                except Exception:
                    # Best-effort: a failed poll only delays the next attempt.
                    time.sleep(self._MONITOR_POLL_INTERVAL_S)
                    continue

                previous_cursor = cursor
                for row in rows:
                    seq = self._row_seq(row)
                    # Advance past EVERY row inspected, including the ones
                    # filtered out below. Otherwise filtered rows are re-fetched
                    # on every poll, and a full page of them would hide all newer
                    # packets.
                    cursor = max(cursor, seq)

                    packet = self._row_packet(row)
                    if not packet or not self._should_relay_packet(packet, request_id_norm):
                        continue
                    if not self._mark_monitor_seq(bot_id, seq):
                        continue

                    self._persist_runtime_packet(bot_id, packet)
                    if loop and loop.is_running():
                        self._broadcast_packet(bot_id, packet, loop)

                # A full page that moved the cursor may have more rows behind it:
                # poll again immediately. Otherwise wait before the next poll.
                if len(rows) < self._MONITOR_PAGE_LIMIT or cursor <= previous_cursor:
                    time.sleep(self._MONITOR_POLL_INTERVAL_S)
        finally:
            thread_key = self._monitor_thread_key(bot_id, request_id_norm, after_seq)
            with self._monitor_sync_lock:
                if self._monitor_sync_threads.get(thread_key) is threading.current_thread():
                    self._monitor_sync_threads.pop(thread_key, None)

    def _resolve_monitor_baseline_seq(self, runtime_provider: RuntimeProvider, bot_id: str) -> int:
        """Return the newest monitor seq currently known for ``bot_id``.

        Starts from the in-memory high-water mark and pages forward with the
        same ``after_seq`` cursor semantics as the sync loop, so bots with more
        than one page of history still get an accurate baseline. Paging stops
        on an empty or short page, when the cursor stops moving, after
        ``_MONITOR_BASELINE_MAX_PAGES`` pages, or on a provider error (the
        best value found so far is used).
        """
        latest_seq = self._get_monitor_seq(bot_id)
        limit = self._MONITOR_BASELINE_PAGE_LIMIT
        for _ in range(self._MONITOR_BASELINE_MAX_PAGES):
            try:
                rows = list(
                    runtime_provider.get_monitor_packets(bot_id=bot_id, after_seq=latest_seq, limit=limit)
                    or []
                )
            except Exception:
                break
            page_max = max((self._row_seq(row) for row in rows), default=0)
            if page_max <= latest_seq:
                break
            latest_seq = page_max
            if len(rows) < limit:
                break
        return max(latest_seq, self._get_monitor_seq(bot_id))

    def _mark_monitor_seq(self, bot_id: str, seq: int) -> bool:
        """Claim ``seq`` for relaying; return True only the first time it is claimed for ``bot_id``.

        Each bot keeps a set of recently relayed seqs, not just a high-water
        mark. With a high-water mark alone, two concurrent syncs for one bot
        would drop each other's packets whenever one relayed a higher seq first.
        Seqs more than ``_MONITOR_SEEN_WINDOW`` below the newest relayed seq are
        treated as already relayed, which keeps the set bounded.
        """
        if seq <= 0:
            return False
        bot_key = str(bot_id or "").strip()
        window = self._MONITOR_SEEN_WINDOW
        with self._monitor_sync_seq_lock:
            high_water = int(self._monitor_sync_last_seq.get(bot_key, 0) or 0)
            seen = self._monitor_sync_seen_seqs.setdefault(bot_key, set())
            if seq in seen or seq <= high_water - window:
                return False
            seen.add(seq)
            if seq > high_water:
                high_water = seq
                self._monitor_sync_last_seq[bot_key] = seq
            if len(seen) > 2 * window:
                floor = high_water - window
                self._monitor_sync_seen_seqs[bot_key] = {s for s in seen if s > floor}
            return True

    def _get_monitor_seq(self, bot_id: str) -> int:
        """Return the newest seq relayed for ``bot_id`` by this process (0 if none)."""
        bot_key = str(bot_id or "").strip()
        with self._monitor_sync_seq_lock:
            return int(self._monitor_sync_last_seq.get(bot_key, 0) or 0)

    def _build_delivery_command(self, *, command: str, checked_attachments: List[str]) -> str:
        """Build the text sent to the runtime for a command and its attachments.

        - No attachments: the command itself, or ``"[attachment message]"`` if empty.
        - All attachments visual (per ``_is_visual_attachment_path``): a prompt
          listing the files plus handling instructions (text is in Chinese);
          wording depends on whether any attachment is a video.
        - Otherwise: the file list is appended unless the command already
          mentions every attachment path. A generic instruction is used when
          there is no command text.
        """
        if not checked_attachments:
            return command if command else "[attachment message]"

        attachment_block = "\n".join(f"- {p}" for p in checked_attachments)

        if all(self._is_visual_attachment_path(p) for p in checked_attachments):
            has_video = any(self._is_video_attachment_path(p) for p in checked_attachments)
            media_label = "图片/视频" if has_video else "图片"
            capability_hint = (
                "1) 附件已随请求附带;图片在可用时可直接作为多模态输入理解,视频请按附件路径处理。\n"
                if has_video
                else "1) 附件中的图片已作为多模态输入提供,优先直接理解并回答。\n"
            )
            if command:
                return (
                    f"{command}\n\n"
                    "[Attached files]\n"
                    f"{attachment_block}\n\n"
                    "【附件处理要求】\n"
                    f"{capability_hint}"
                    "2) 若当前模型或接口不支持直接理解该附件,请明确说明后再调用工具解析。\n"
                    "3) 除非用户明确要求,不要先调用工具读取附件文件。\n"
                    "4) 回复语言必须遵循 USER.md;若未指定,则与用户当前输入语言保持一致。\n"
                    "5) 仅基于可见内容回答;看不清或无法确认的部分请明确说明,不要猜测。"
                )
            return (
                "请先处理已附带的附件列表:\n"
                f"{attachment_block}\n\n"
                f"请直接分析已附带的{media_label}并总结关键信息。\n"
                f"{'图片在可用时可直接作为多模态输入理解,视频请按附件路径处理。' if has_video else ''}\n"
                "若当前模型或接口不支持直接理解该附件,请明确说明后再调用工具解析。\n"
                "回复语言必须遵循 USER.md;若未指定,则与用户当前输入语言保持一致。\n"
                "仅基于可见内容回答;看不清或无法确认的部分请明确说明,不要猜测。"
            )

        if not command:
            return (
                "Please process the uploaded file(s) listed below:\n"
                f"{attachment_block}\n\n"
                "Reply language must follow USER.md. If not specified, use the same language as the user input."
            )

        # The user already referenced every attachment path; send the command verbatim.
        if all(p in command for p in checked_attachments):
            return command

        return (
            f"{command}\n\n"
            "[Attached files]\n"
            f"{attachment_block}\n\n"
            "Please process the attached file(s) listed above when answering this request.\n"
            "Reply language must follow USER.md. If not specified, use the same language as the user input."
        )