166 lines
7.7 KiB
Python
166 lines
7.7 KiB
Python
import asyncio
|
|
from typing import Any, Callable
|
|
|
|
from fastapi import HTTPException, WebSocket, WebSocketDisconnect
|
|
from sqlmodel import Session, select
|
|
|
|
from models.bot import BotInstance
|
|
from models.platform import BotRequestUsage
|
|
|
|
|
|
class AppLifecycleService:
    """Run application startup tasks and serve per-bot WebSocket connections.

    Every collaborator (database engine, cache, logger, node registry,
    runtime services and the provider-target helper callables) is injected
    through the constructor and stored on a private attribute; the class
    constructs none of them itself.
    """

    # Bot attributes compared against the resolved provider target on startup.
    _TARGET_FIELDS = ("node_id", "transport_kind", "runtime_kind", "core_adapter")

    def __init__(
        self,
        *,
        app: Any,
        engine: Any,
        cache: Any,
        logger: Any,
        project_root: str,
        database_engine: str,
        database_echo: Any,
        database_url_display: str,
        redis_enabled: bool,
        init_database: Callable[[], None],
        node_registry_service: Any,
        local_managed_node: Callable[[], Any],
        prune_expired_activity_events: Callable[..., int],
        migrate_bot_resources_store: Callable[[str], None],
        resolve_bot_provider_target_for_instance: Callable[[Any], Any],
        default_provider_target: Callable[[], Any],
        set_bot_provider_target: Callable[[str, Any], None],
        apply_provider_target_to_bot: Callable[[Any, Any], None],
        normalize_provider_target: Callable[[Any], Any],
        runtime_service: Any,
        runtime_event_service: Any,
        clear_provider_target_overrides: Callable[[], None],
    ) -> None:
        """Store the injected collaborators; all arguments are keyword-only.

        No work is performed here beyond attribute assignment. The values
        ``project_root``, ``database_engine``, ``database_echo`` and
        ``database_url_display`` are only used for the startup log line.
        """
        # Application object and infrastructure handles.
        self._app = app
        self._engine = engine
        self._cache = cache
        self._logger = logger

        # Values reported in the startup log line.
        self._project_root = project_root
        self._database_engine = database_engine
        self._database_echo = database_echo
        self._database_url_display = database_url_display
        self._redis_enabled = redis_enabled

        # Startup hooks.
        self._init_database = init_database
        self._node_registry_service = node_registry_service
        self._local_managed_node = local_managed_node
        self._prune_expired_activity_events = prune_expired_activity_events
        self._migrate_bot_resources_store = migrate_bot_resources_store

        # Provider-target helpers.
        self._resolve_bot_provider_target_for_instance = resolve_bot_provider_target_for_instance
        self._default_provider_target = default_provider_target
        self._set_bot_provider_target = set_bot_provider_target
        self._apply_provider_target_to_bot = apply_provider_target_to_bot
        self._normalize_provider_target = normalize_provider_target
        self._clear_provider_target_overrides = clear_provider_target_overrides

        # Runtime monitoring and event fan-out.
        self._runtime_service = runtime_service
        self._runtime_event_service = runtime_event_service

    async def on_startup(self) -> None:
        """Prepare application state when the server starts.

        In order: remember the running event loop on ``app.state.main_loop``,
        clear provider-target overrides, log the effective configuration,
        initialise the database, delete cache entries under the empty prefix,
        then - inside a single database session - load the node registry,
        upsert the local managed node, prune expired activity events,
        reconcile every bot's provider target and restore runtime monitors
        for bots whose ``docker_status`` is ``"RUNNING"``.
        """
        self._app.state.main_loop = asyncio.get_running_loop()
        self._clear_provider_target_overrides()
        self._logger.info(
            "startup project_root=%s db_engine=%s db_echo=%s db_url=%s redis=%s",
            self._project_root,
            self._database_engine,
            self._database_echo,
            self._database_url_display,
            self._redis_status(),
        )
        self._init_database()
        # An empty prefix presumably matches every key, i.e. a full cache
        # flush - confirm against the cache implementation.
        self._cache.delete_prefix("")

        with Session(self._engine) as session:
            self._node_registry_service.load_from_session(session)
            self._node_registry_service.upsert_node(session, self._local_managed_node())

            pruned_events = self._prune_expired_activity_events(session, force=True)
            if pruned_events > 0:
                session.commit()

            self._reconcile_provider_targets(session)
            self._restore_runtime_monitors(session)

    def _redis_status(self) -> str:
        """Return the cache status label used in the startup log line."""
        if self._cache.ping():
            return "enabled"
        return "disabled" if self._redis_enabled else "not_configured"

    def _reconcile_provider_targets(self, session: Session) -> None:
        """Register an ``edge`` provider target for every bot and persist drift.

        For each ``BotInstance`` the resources store is migrated and the
        provider target resolved. A target whose ``transport_kind`` is not
        ``"edge"`` is rebuilt through ``normalize_provider_target`` with the
        transport forced to ``"edge"``. Bots whose stored fields differ from
        the target are updated, and one commit is issued if any bot changed.
        """
        target_dirty = False
        for bot in session.exec(select(BotInstance)).all():
            self._migrate_bot_resources_store(bot.id)

            target = self._resolve_bot_provider_target_for_instance(bot)
            if str(target.transport_kind or "").strip().lower() != "edge":
                target = self._normalize_provider_target(
                    {
                        "node_id": target.node_id,
                        "transport_kind": "edge",
                        "runtime_kind": target.runtime_kind,
                        "core_adapter": target.core_adapter,
                    },
                    fallback=self._default_provider_target(),
                )

            # NOTE(review): the target is registered for every bot, not only
            # for the ones rewritten above - confirm this matches intent.
            self._set_bot_provider_target(bot.id, target)

            # Stored values are normalised (stripped, lower-cased) before the
            # comparison; the target's values are compared as-is.
            drifted = any(
                str(getattr(bot, field, "") or "").strip().lower() != getattr(target, field)
                for field in self._TARGET_FIELDS
            )
            if drifted:
                self._apply_provider_target_to_bot(bot, target)
                session.add(bot)
                target_dirty = True

        if target_dirty:
            session.commit()

    def _restore_runtime_monitors(self, session: Session) -> None:
        """Re-attach runtime monitors for bots marked ``RUNNING``.

        For each such bot the monitor is ensured and, when the most recent
        ``PENDING`` usage row carries a non-blank ``request_id``, edge monitor
        packets are synced for that request. Failures are logged per bot and
        never abort startup.
        """
        running_bots = session.exec(
            select(BotInstance).where(BotInstance.docker_status == "RUNNING")
        ).all()
        for bot in running_bots:
            try:
                self._runtime_service.ensure_monitor(app_state=self._app.state, bot=bot)

                pending_usage = session.exec(
                    select(BotRequestUsage)
                    .where(BotRequestUsage.bot_id == str(bot.id or "").strip())
                    .where(BotRequestUsage.status == "PENDING")
                    .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
                    .limit(1)
                ).first()

                request_id = ""
                if pending_usage:
                    request_id = str(getattr(pending_usage, "request_id", "") or "").strip()
                if request_id:
                    self._runtime_service.sync_edge_monitor_packets(
                        app_state=self._app.state,
                        bot=bot,
                        request_id=request_id,
                    )
            except HTTPException as exc:
                # Logged as a warning without traceback: the message treats
                # this as an unavailable runtime backend, not a code fault.
                self._logger.warning(
                    "Skip runtime monitor restore on startup for bot_id=%s due to unavailable runtime backend: %s",
                    str(bot.id or ""),
                    str(getattr(exc, "detail", "") or exc),
                )
            except Exception:
                self._logger.exception(
                    "Failed to restore runtime monitor on startup for bot_id=%s", str(bot.id or "")
                )

    async def handle_websocket(self, websocket: WebSocket, bot_id: str) -> None:
        """Serve one WebSocket connection for ``bot_id`` until it closes.

        The socket is closed with code 4404 when the bot does not exist and
        with code 1011 when registering it with the event manager fails.
        Otherwise the bot's runtime monitor is ensured and incoming text
        frames are read and discarded until the peer disconnects.

        Once the manager has accepted the socket it is always unregistered
        again, including when ``ensure_monitor`` raises. An ``ensure_monitor``
        error still propagates to the caller.
        """
        with Session(self._engine) as session:
            bot = session.get(BotInstance, bot_id)
            if not bot:
                await websocket.close(code=4404, reason="Bot not found")
                return

        try:
            await self._runtime_event_service.manager.connect(bot_id, websocket)
        except Exception as exc:
            self._logger.warning("websocket connect failed bot_id=%s detail=%s", bot_id, exc)
            try:
                await websocket.close(code=1011, reason="WebSocket accept failed")
            except Exception:
                # Best-effort close: the socket may already be unusable.
                pass
            return

        # From here on the manager holds a reference to the socket, so every
        # exit path - including a failing ensure_monitor - must disconnect.
        try:
            self._runtime_service.ensure_monitor(app_state=websocket.app.state, bot=bot)
            await self._receive_until_closed(websocket, bot_id)
        finally:
            self._runtime_event_service.manager.disconnect(bot_id, websocket)

    async def _receive_until_closed(self, websocket: WebSocket, bot_id: str) -> None:
        """Read and discard text frames until the connection ends.

        A ``WebSocketDisconnect`` ends the loop silently. A ``RuntimeError``
        whose message says the socket was never accepted or is no longer
        connected is also ignored. Any other error is logged with traceback
        and swallowed.
        """
        try:
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        except RuntimeError as exc:
            msg = str(exc).lower()
            benign = 'need to call "accept" first' in msg or "not connected" in msg
            if not benign:
                self._logger.exception("websocket runtime error bot_id=%s", bot_id)
        except Exception:
            self._logger.exception("websocket unexpected error bot_id=%s", bot_id)