import asyncio
from typing import Any, Callable

from fastapi import HTTPException, WebSocket, WebSocketDisconnect
from sqlmodel import Session, select

from models.bot import BotInstance
from models.platform import BotRequestUsage


class AppLifecycleService:
    """Application startup routine and per-bot WebSocket endpoint handler.

    Every collaborator (database engine, cache, logger, registries, provider
    target helpers, runtime services) is injected through the constructor and
    stored on a private attribute; the class itself holds no other state.
    """

    # Attributes compared between a stored bot row and its resolved provider
    # target to decide whether the row has to be rewritten.
    _TARGET_FIELDS = ("node_id", "transport_kind", "runtime_kind", "core_adapter")

    def __init__(
        self,
        *,
        app: Any,
        engine: Any,
        cache: Any,
        logger: Any,
        project_root: str,
        database_engine: str,
        database_echo: Any,
        database_url_display: str,
        redis_enabled: bool,
        init_database: Callable[[], None],
        node_registry_service: Any,
        local_managed_node: Callable[[], Any],
        prune_expired_activity_events: Callable[..., int],
        migrate_bot_resources_store: Callable[[str], None],
        resolve_bot_provider_target_for_instance: Callable[[Any], Any],
        default_provider_target: Callable[[], Any],
        set_bot_provider_target: Callable[[str, Any], None],
        apply_provider_target_to_bot: Callable[[Any, Any], None],
        normalize_provider_target: Callable[[Any], Any],
        runtime_service: Any,
        runtime_event_service: Any,
        clear_provider_target_overrides: Callable[[], None],
    ) -> None:
        """Store the injected collaborators; performs no work beyond assignment.

        All parameters are keyword-only. ``project_root``, ``database_engine``,
        ``database_echo``, ``database_url_display`` and ``redis_enabled`` are
        only used for the startup log line.
        """
        self._app = app
        self._engine = engine
        self._cache = cache
        self._logger = logger
        self._project_root = project_root
        self._database_engine = database_engine
        self._database_echo = database_echo
        self._database_url_display = database_url_display
        self._redis_enabled = redis_enabled
        self._init_database = init_database
        self._node_registry_service = node_registry_service
        self._local_managed_node = local_managed_node
        self._prune_expired_activity_events = prune_expired_activity_events
        self._migrate_bot_resources_store = migrate_bot_resources_store
        self._resolve_bot_provider_target_for_instance = resolve_bot_provider_target_for_instance
        self._default_provider_target = default_provider_target
        self._set_bot_provider_target = set_bot_provider_target
        self._apply_provider_target_to_bot = apply_provider_target_to_bot
        self._normalize_provider_target = normalize_provider_target
        self._runtime_service = runtime_service
        self._runtime_event_service = runtime_event_service
        self._clear_provider_target_overrides = clear_provider_target_overrides

    async def on_startup(self) -> None:
        """Run the startup sequence.

        In order: record the running event loop on ``app.state.main_loop``,
        clear provider target overrides, log the configuration summary,
        initialise the database, call ``cache.delete_prefix("")``, then, inside
        a single database session, load and upsert node registry data, prune
        expired activity events, reconcile every bot's provider target and
        restore runtime monitors for bots whose ``docker_status`` is
        ``"RUNNING"``.
        """
        self._app.state.main_loop = asyncio.get_running_loop()
        self._clear_provider_target_overrides()
        self._log_startup_summary()
        self._init_database()
        # NOTE(review): an empty prefix presumably matches every cached key,
        # i.e. this wipes the cache on startup - confirm against the cache
        # implementation.
        self._cache.delete_prefix("")

        with Session(self._engine) as session:
            self._node_registry_service.load_from_session(session)
            self._node_registry_service.upsert_node(session, self._local_managed_node())
            pruned_events = self._prune_expired_activity_events(session, force=True)
            if pruned_events > 0:
                session.commit()
            self._reconcile_provider_targets(session)
            self._restore_runtime_monitors(session)

    def _log_startup_summary(self) -> None:
        """Log project root, database settings and the cache status."""
        if self._cache.ping():
            redis_state = "enabled"
        elif self._redis_enabled:
            # Configured, but the ping did not succeed.
            redis_state = "disabled"
        else:
            redis_state = "not_configured"
        self._logger.info(
            "startup project_root=%s db_engine=%s db_echo=%s db_url=%s redis=%s",
            self._project_root,
            self._database_engine,
            self._database_echo,
            self._database_url_display,
            redis_state,
        )

    def _target_differs(self, bot: Any, target: Any) -> bool:
        """Return True when any target field on ``bot`` differs from ``target``.

        The bot side is stripped and lower-cased before comparison; the target
        side is compared as-is.
        """
        return any(
            str(getattr(bot, field, "") or "").strip().lower() != getattr(target, field)
            for field in self._TARGET_FIELDS
        )

    def _reconcile_provider_targets(self, session: Session) -> None:
        """Migrate each bot's resource store and align its provider target.

        Targets whose ``transport_kind`` is not ``"edge"`` are re-normalised
        with ``transport_kind`` forced to ``"edge"``. The resulting target is
        registered for every bot, and bot rows that disagree with it are
        updated; a single commit is issued if at least one row changed.
        """
        target_dirty = False
        for bot in session.exec(select(BotInstance)).all():
            self._migrate_bot_resources_store(bot.id)
            target = self._resolve_bot_provider_target_for_instance(bot)
            if str(target.transport_kind or "").strip().lower() != "edge":
                target = self._normalize_provider_target(
                    {
                        "node_id": target.node_id,
                        "transport_kind": "edge",
                        "runtime_kind": target.runtime_kind,
                        "core_adapter": target.core_adapter,
                    },
                    fallback=self._default_provider_target(),
                )
            self._set_bot_provider_target(bot.id, target)
            if self._target_differs(bot, target):
                self._apply_provider_target_to_bot(bot, target)
                session.add(bot)
                target_dirty = True
        if target_dirty:
            session.commit()

    def _restore_runtime_monitors(self, session: Session) -> None:
        """Re-attach monitors for running bots and resync a pending request.

        For each bot with ``docker_status == "RUNNING"`` the monitor is ensured
        and, if the newest ``PENDING`` usage row carries a non-blank
        ``request_id``, edge monitor packets are synced for that request.
        Failures are logged per bot and never abort the loop.
        """
        running_bots = session.exec(
            select(BotInstance).where(BotInstance.docker_status == "RUNNING")
        ).all()
        for bot in running_bots:
            try:
                self._runtime_service.ensure_monitor(app_state=self._app.state, bot=bot)
                latest_pending = (
                    select(BotRequestUsage)
                    .where(BotRequestUsage.bot_id == str(bot.id or "").strip())
                    .where(BotRequestUsage.status == "PENDING")
                    .order_by(BotRequestUsage.started_at.desc(), BotRequestUsage.id.desc())
                    .limit(1)
                )
                pending_usage = session.exec(latest_pending).first()
                if pending_usage and str(getattr(pending_usage, "request_id", "") or "").strip():
                    self._runtime_service.sync_edge_monitor_packets(
                        app_state=self._app.state,
                        bot=bot,
                        request_id=str(pending_usage.request_id or "").strip(),
                    )
            except HTTPException as exc:
                # Logged as a warning without a traceback, unlike other errors.
                self._logger.warning(
                    "Skip runtime monitor restore on startup for bot_id=%s due to unavailable runtime backend: %s",
                    str(bot.id or ""),
                    str(getattr(exc, "detail", "") or exc),
                )
            except Exception:
                self._logger.exception(
                    "Failed to restore runtime monitor on startup for bot_id=%s",
                    str(bot.id or ""),
                )

    async def handle_websocket(self, websocket: WebSocket, bot_id: str) -> None:
        """Serve one WebSocket client subscribed to ``bot_id``.

        Closes with code 4404 when the bot does not exist and with code 1011
        when registering the socket with the event manager fails. Otherwise the
        bot's monitor is ensured and incoming text frames are read and
        discarded until the client goes away. Once the socket has been
        registered it is always deregistered again, including when
        ``ensure_monitor`` raises (that exception still propagates).
        """
        with Session(self._engine) as session:
            bot = session.get(BotInstance, bot_id)
        if not bot:
            await websocket.close(code=4404, reason="Bot not found")
            return

        manager = self._runtime_event_service.manager
        try:
            await manager.connect(bot_id, websocket)
        except Exception as exc:
            self._logger.warning("websocket connect failed bot_id=%s detail=%s", bot_id, exc)
            try:
                await websocket.close(code=1011, reason="WebSocket accept failed")
            except Exception:
                # Best effort: the socket may already be unusable.
                pass
            return

        try:
            # Kept inside the cleanup scope: previously a failure here left the
            # socket registered in the manager with nothing to remove it.
            self._runtime_service.ensure_monitor(app_state=websocket.app.state, bot=bot)
            await self._drain_client_messages(websocket, bot_id)
        finally:
            manager.disconnect(bot_id, websocket)

    async def _drain_client_messages(self, websocket: WebSocket, bot_id: str) -> None:
        """Read and discard text frames until the connection ends.

        A ``WebSocketDisconnect`` ends the loop silently, as does a
        ``RuntimeError`` whose message indicates the socket is not accepted or
        not connected; any other error is logged with a traceback.
        """
        try:
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        except RuntimeError as exc:
            msg = str(exc or "").lower()
            if "need to call \"accept\" first" not in msg and "not connected" not in msg:
                self._logger.exception("websocket runtime error bot_id=%s", bot_id)
        except Exception:
            self._logger.exception("websocket unexpected error bot_id=%s", bot_id)