from typing import Any, Callable, Optional
from urllib.parse import unquote

from fastapi import Request
from fastapi.responses import JSONResponse
from sqlmodel import Session

from models.bot import BotInstance


class DashboardAuthService:
    """Request guard that authenticates dashboard API calls.

    ``guard`` is written in the ``(request, call_next)`` middleware style. It
    lets ``OPTIONS`` requests and a fixed allow-list of public endpoints
    through. Every other request needs a valid session token. Requests under
    ``/api/bots/<bot_id>/...`` are additionally checked for per-bot access,
    bot existence and the bot's ``enabled`` flag.
    """

    # Headers are looked up case-insensitively via request.headers.
    AUTH_TOKEN_HEADER = "authorization"
    AUTH_TOKEN_FALLBACK_HEADER = "x-auth-token"
    # Prefix of every bot-scoped API path; the next segment is the bot id.
    _BOTS_API_PREFIX = "/api/bots/"

    def __init__(self, *, engine: Any) -> None:
        """Store the DB engine used to open a ``Session`` per guarded request."""
        self._engine = engine

    def extract_bot_id_from_api_path(self, path: str) -> Optional[str]:
        """Return the bot id from ``/api/bots/<bot_id>[/...]``, or ``None``.

        The first segment after the prefix is percent-decoded and stripped.
        ``None`` is returned when the path is not under the prefix or the
        resulting id is empty.
        """
        raw = str(path or "").strip()
        if not raw.startswith(self._BOTS_API_PREFIX):
            return None
        bot_id_segment = raw[len(self._BOTS_API_PREFIX):].split("/", 1)[0].strip()
        if not bot_id_segment:
            return None
        try:
            decoded = unquote(bot_id_segment)
        except Exception:
            # Best-effort decoding: fall back to the raw segment.
            decoded = bot_id_segment
        return str(decoded).strip() or None

    def get_supplied_auth_token_http(self, request: Request) -> str:
        """Return the auth token supplied with *request*, or ``""`` if none.

        Sources are tried in order:
        1. ``Authorization: Bearer <token>`` (scheme matched case-insensitively);
        2. the ``X-Auth-Token`` header;
        3. the ``auth_token`` query parameter.
        """
        auth_header = str(request.headers.get(self.AUTH_TOKEN_HEADER) or "").strip()
        if auth_header.lower().startswith("bearer "):
            token = auth_header[len("bearer "):].strip()
            if token:
                return token
        header_value = str(request.headers.get(self.AUTH_TOKEN_FALLBACK_HEADER) or "").strip()
        if header_value:
            return header_value
        return str(request.query_params.get("auth_token") or "").strip()

    @staticmethod
    def is_public_api_path(path: str, method: str = "GET") -> bool:
        """Return True if *path* is one of the unauthenticated API endpoints.

        *method* is accepted for signature symmetry but is not consulted.
        """
        raw = str(path or "").strip()
        if not raw.startswith("/api/"):
            return False
        return raw in {
            "/api/sys/auth/status",
            "/api/sys/auth/login",
            "/api/sys/auth/logout",
            "/api/health",
            "/api/health/cache",
        }

    def is_bot_enable_api_path(self, path: str, method: str = "GET") -> bool:
        """Return True for ``POST /api/bots/<bot_id>/enable``.

        The path is matched by segment structure rather than by rebuilding
        it from the decoded bot id. Ids containing percent-encoded characters
        therefore still match; the previous rebuilt-string comparison never
        matched such paths.
        """
        if str(method or "GET").strip().upper() != "POST":
            return False
        raw = str(path or "").strip()
        if not self.extract_bot_id_from_api_path(raw):
            return False
        segments = raw[len(self._BOTS_API_PREFIX):].split("/")
        return len(segments) == 2 and segments[1] == "enable"

    def validate_dashboard_auth(self, request: Request, session: Session) -> Optional[str]:
        """Authenticate *request* by its session token.

        On success, sets ``request.state.sys_auth_mode``, ``sys_user_id`` and
        ``sys_username``, then returns ``None``. On failure, returns a
        human-readable error message instead.
        """
        token = self.get_supplied_auth_token_http(request)
        if not token:
            return "Authentication required"
        # Local import, as in the original module (presumably to avoid an import cycle).
        from services.sys_auth_service import resolve_user_by_token

        user = resolve_user_by_token(session, token)
        if user is None:
            return "Session expired or invalid"
        request.state.sys_auth_mode = "session_token"
        request.state.sys_user_id = int(user.id or 0)
        request.state.sys_username = str(user.username or "")
        return None

    @staticmethod
    def _json_error(request: Request, *, status_code: int, detail: str) -> JSONResponse:
        """Build a ``{"detail": ...}`` JSON error response with a wildcard CORS header.

        ``Vary: Origin`` is added only when the request carried an ``Origin`` header.
        """
        headers = {"Access-Control-Allow-Origin": "*"}
        origin = str(request.headers.get("origin") or "").strip()
        if origin:
            headers["Vary"] = "Origin"
        return JSONResponse(status_code=status_code, content={"detail": detail}, headers=headers)

    async def guard(self, request: Request, call_next: Callable[..., Any]):
        """Middleware entry point enforcing dashboard authentication.

        * ``OPTIONS`` requests and public API paths pass straight through.
        * All other requests need a valid session token (401 otherwise).
        * Requests under ``/api/bots/<bot_id>`` also require:
          - that the user may access the bot (403 otherwise);
          - that the bot exists (404 otherwise);
          - if the bot is disabled, that the request is a ``GET`` or the
            enable endpoint (403 otherwise).

        All checks share one DB session. That session is closed before the
        request is handed to *call_next*, so no connection is held while the
        downstream handler runs.
        """
        if request.method.upper() == "OPTIONS":
            return await call_next(request)
        if self.is_public_api_path(request.url.path, request.method):
            return await call_next(request)

        with Session(self._engine) as session:
            error_response = self._check_request_access(request, session)
        if error_response is not None:
            return error_response
        return await call_next(request)

    def _check_request_access(self, request: Request, session: Session) -> Optional[JSONResponse]:
        """Run the auth and per-bot checks; return an error response, or ``None`` to allow."""
        auth_error = self.validate_dashboard_auth(request, session)
        if auth_error:
            return self._json_error(request, status_code=401, detail=auth_error)

        bot_id = self.extract_bot_id_from_api_path(request.url.path)
        if not bot_id:
            # Authenticated request that is not bot-scoped: nothing more to check.
            return None

        # Local imports, matching the original module (presumably to avoid import cycles).
        from models.sys_auth import SysUser
        from services.sys_auth_service import user_can_access_bot

        current_user_id = int(getattr(request.state, "sys_user_id", 0) or 0)
        # This reuses the session from the token lookup. If resolve_user_by_token
        # loaded this SysUser through it, Session.get can answer from the identity map.
        current_user = session.get(SysUser, current_user_id) if current_user_id > 0 else None
        if current_user is None:
            return self._json_error(request, status_code=401, detail="Authentication required")
        if not user_can_access_bot(session, current_user, bot_id):
            return self._json_error(request, status_code=403, detail="You do not have access to this bot")

        bot = session.get(BotInstance, bot_id)
        if not bot:
            return self._json_error(request, status_code=404, detail="Bot not found")

        # A disabled bot still allows reads and the enable call itself.
        if not bool(getattr(bot, "enabled", True)):
            is_enable_api = self.is_bot_enable_api_path(request.url.path, request.method)
            is_read_api = request.method.upper() == "GET"
            if not (is_enable_api or is_read_api):
                return self._json_error(request, status_code=403, detail="Bot is disabled. Enable it first.")
        return None