from __future__ import annotations

import hmac
from typing import Optional

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from core.settings import PANEL_ACCESS_PASSWORD
from services.bot_storage_service import _read_bot_config

PANEL_ACCESS_PASSWORD_HEADER = "x-panel-password"
BOT_ACCESS_PASSWORD_HEADER = "X-Bot-Access-Password"
# NOTE(review): this set is not referenced anywhere in this module. The
# middleware below still accepts a bot access password for /enable, /disable
# and /deactivate. Kept because other modules may import it; confirm whether
# these actions were meant to be panel-only.
BOT_PANEL_ONLY_SUFFIXES = {"/enable", "/disable", "/deactivate"}

# Path suffixes under /api/bots/{bot_id} that count as bot management actions.
_BOT_MANAGEMENT_SUFFIXES = ("/start", "/stop", "/enable", "/disable", "/deactivate")

# Exact /api paths that are never subject to the panel password check.
_PANEL_PUBLIC_API_PATHS = frozenset(
    {
        "/api/panel/auth/status",
        "/api/panel/auth/login",
        "/api/health",
        "/api/health/cache",
    }
)


def _extract_bot_id_from_api_path(path: str) -> Optional[str]:
    """Return the ``{bot_id}`` segment of an ``/api/bots/{bot_id}[/...]`` path.

    Empty and whitespace-only segments are dropped before matching, so
    ``"/api/bots/abc/start"`` yields ``"abc"``. Any path with fewer than three
    remaining segments (e.g. ``"/api/bots"``) yields ``None``.
    """
    parts = [p for p in path.split("/") if p.strip()]
    if len(parts) >= 3 and parts[0] == "api" and parts[1] == "bots":
        return parts[2]
    return None


def _supplied_secret_http(request: Request, header_name: str, query_name: str) -> str:
    """Return the stripped secret sent in *header_name*, else in query param *query_name*.

    A non-blank header value wins. Otherwise the query parameter is used.
    Returns ``""`` when neither is present.
    """
    header_value = str(request.headers.get(header_name) or "").strip()
    if header_value:
        return header_value
    return str(request.query_params.get(query_name) or "").strip()


def _get_supplied_panel_password_http(request: Request) -> str:
    """Return the panel password the client supplied (header first, then query string)."""
    return _supplied_secret_http(request, PANEL_ACCESS_PASSWORD_HEADER, "panel_access_password")


def _get_supplied_bot_access_password_http(request: Request) -> str:
    """Return the bot access password the client supplied (header first, then query string)."""
    return _supplied_secret_http(request, BOT_ACCESS_PASSWORD_HEADER, "bot_access_password")


def _check_password(
    configured: object, supplied: object, missing_msg: str, invalid_msg: str
) -> Optional[str]:
    """Compare *supplied* against *configured*; return an error message or ``None``.

    Both values are stringified and stripped. If nothing is configured the check
    passes (``None``). If nothing was supplied, *missing_msg* is returned. If the
    values differ, *invalid_msg* is returned.
    """
    expected = str(configured or "").strip()
    if not expected:
        return None
    candidate = str(supplied or "").strip()
    if not candidate:
        return missing_msg
    # Constant-time comparison so response timing does not reveal how much of
    # the secret matched. The values are encoded because compare_digest only
    # accepts ASCII-only str.
    if not hmac.compare_digest(candidate.encode("utf-8"), expected.encode("utf-8")):
        return invalid_msg
    return None


def _validate_panel_access_password(supplied: str) -> Optional[str]:
    """Check *supplied* against ``PANEL_ACCESS_PASSWORD``.

    Returns ``None`` when it matches or when no panel password is configured.
    Otherwise returns an error message.
    """
    return _check_password(
        PANEL_ACCESS_PASSWORD,
        supplied,
        "Panel access password required",
        "Invalid panel access password",
    )


def _validate_bot_access_password(bot_id: str, supplied: str) -> Optional[str]:
    """Check *supplied* against the ``access_password`` in *bot_id*'s stored config.

    Returns ``None`` when it matches or when the bot has no access password.
    Otherwise returns an error message. Assumes ``_read_bot_config`` returns a
    mapping; verify its behaviour for unknown bot ids.
    """
    config = _read_bot_config(bot_id)
    return _check_password(
        config.get("access_password"),
        supplied,
        "Bot access password required",
        "Invalid bot access password",
    )


def _is_bot_panel_management_api_path(path: str, method: str = "GET") -> bool:
    """Return True for bot management endpoints under ``/api/bots/{bot_id}``.

    These are:

    - any path ending in one of ``_BOT_MANAGEMENT_SUFFIXES``;
    - ``PUT`` or ``DELETE`` on exactly ``/api/bots/{bot_id}``.
    """
    raw = str(path or "").strip()
    verb = str(method or "GET").strip().upper()
    if not raw.startswith("/api/bots/"):
        return False
    bot_id = _extract_bot_id_from_api_path(raw)
    if not bot_id:
        return False
    if raw.endswith(_BOT_MANAGEMENT_SUFFIXES):
        return True
    return verb in {"PUT", "DELETE"} and raw == f"/api/bots/{bot_id}"


def _is_panel_protected_api_path(path: str, method: str = "GET") -> bool:
    """Return True when *path* is an ``/api`` route that requires the panel password.

    The following are not protected:

    - paths outside ``/api/``;
    - the public paths in ``_PANEL_PUBLIC_API_PATHS``;
    - bot-scoped paths that are not management actions.
    """
    raw = str(path or "").strip()
    verb = str(method or "GET").strip().upper()
    if not raw.startswith("/api/") or raw in _PANEL_PUBLIC_API_PATHS:
        return False
    if _is_bot_panel_management_api_path(raw, verb):
        return True
    return not _extract_bot_id_from_api_path(raw)


class PasswordProtectionMiddleware(BaseHTTPMiddleware):
    """Reject API requests that lack a required panel or bot access password.

    - ``OPTIONS`` requests always pass through.
    - Paths without a bot id need the panel password when
      ``_is_panel_protected_api_path`` says so.
    - Bot management paths pass with a valid panel password. Failing that, the
      bot's own access password is checked.
    - Any other bot-scoped path passes through without a check here.

    Rejections are ``401`` responses with body ``{"detail": <message>}``.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        method = request.method.upper()
        if method == "OPTIONS":
            return await call_next(request)

        bot_id = _extract_bot_id_from_api_path(path)
        if not bot_id:
            if _is_panel_protected_api_path(path, method):
                panel_error = _validate_panel_access_password(
                    _get_supplied_panel_password_http(request)
                )
                if panel_error:
                    return JSONResponse(status_code=401, content={"detail": panel_error})
            return await call_next(request)

        if _is_bot_panel_management_api_path(path, method):
            panel_error = _validate_panel_access_password(
                _get_supplied_panel_password_http(request)
            )
            if panel_error:
                # Fall back to the bot's own password. Only the bot error is
                # reported. NOTE(review): when the bot has no access_password
                # configured, this passes even though the panel check failed.
                # Confirm that is intended.
                bot_error = _validate_bot_access_password(
                    bot_id, _get_supplied_bot_access_password_http(request)
                )
                if bot_error:
                    return JSONResponse(status_code=401, content={"detail": bot_error})
        # NOTE(review): non-management bot routes are not checked by this
        # middleware; presumably they are guarded elsewhere. Verify.
        return await call_next(request)