130 lines
5.1 KiB
Python
130 lines
5.1 KiB
Python
from typing import Any, Callable, Optional
|
|
from urllib.parse import unquote
|
|
|
|
from fastapi import Request
|
|
from fastapi.responses import JSONResponse
|
|
from sqlmodel import Session
|
|
|
|
from models.bot import BotInstance
|
|
|
|
|
|
class DashboardAuthService:
    """Authentication and per-bot authorization checks for dashboard API requests.

    ``guard`` is a middleware-style ``(request, call_next)`` coroutine:

    * CORS preflight (``OPTIONS``) and the paths in ``PUBLIC_API_PATHS`` pass
      straight through.
    * Every other request must carry a session token that
      ``resolve_user_by_token`` accepts (401 otherwise).
    * For ``/api/bots/<bot_id>/...`` paths the user must also be allowed to
      access the bot (403), the bot must exist (404), and a disabled bot only
      accepts ``GET`` requests and ``POST /api/bots/<bot_id>/enable`` (403).
    """

    AUTH_TOKEN_HEADER = "authorization"
    AUTH_TOKEN_FALLBACK_HEADER = "x-auth-token"

    # Prefix of every per-bot API path; the segment after it is the bot id.
    BOT_API_PREFIX = "/api/bots/"

    # Exact paths that are served without authentication.
    PUBLIC_API_PATHS = frozenset(
        {
            "/api/sys/auth/status",
            "/api/sys/auth/login",
            "/api/sys/auth/logout",
            "/api/health",
            "/api/health/cache",
        }
    )

    def __init__(self, *, engine: Any) -> None:
        """Store the engine used to open a ``Session`` for each guarded request."""
        self._engine = engine

    def extract_bot_id_from_api_path(self, path: str) -> Optional[str]:
        """Return the bot id from a ``/api/bots/<bot_id>[/...]`` path, or ``None``.

        The first segment after the prefix is whitespace-stripped,
        percent-decoded with ``urllib.parse.unquote``, and stripped again.
        Paths outside the prefix, or with an empty id, yield ``None``.
        """
        raw = str(path or "").strip()
        prefix = self.BOT_API_PREFIX
        if not raw.startswith(prefix):
            return None
        segment = raw[len(prefix):].split("/", 1)[0].strip()
        if not segment:
            return None
        # unquote() on a str never raises: malformed escapes are kept verbatim
        # and undecodable bytes are replaced, so no exception handling is needed.
        return unquote(segment).strip() or None

    def get_supplied_auth_token_http(self, request: Request) -> str:
        """Return the auth token supplied with *request*, or ``""`` if none.

        Sources, in order of precedence:

        1. ``Authorization: Bearer <token>`` (scheme matched case-insensitively).
        2. The ``X-Auth-Token`` header.
        3. The ``auth_token`` query parameter.
        """
        auth_header = str(request.headers.get(self.AUTH_TOKEN_HEADER) or "").strip()
        if auth_header.lower().startswith("bearer "):
            token = auth_header[len("bearer "):].strip()
            if token:
                return token

        header_value = str(request.headers.get(self.AUTH_TOKEN_FALLBACK_HEADER) or "").strip()
        if header_value:
            return header_value

        # NOTE(review): tokens passed in the query string can leak into access
        # logs and browser history; kept for backward compatibility with clients
        # that cannot set headers — confirm whether it is still required.
        return str(request.query_params.get("auth_token") or "").strip()

    @staticmethod
    def is_public_api_path(path: str, method: str = "GET") -> bool:
        """Return True if *path* is an exact match in ``PUBLIC_API_PATHS``.

        *method* is accepted for signature symmetry with the other path checks
        but does not affect the result.
        """
        return str(path or "").strip() in DashboardAuthService.PUBLIC_API_PATHS

    def is_bot_enable_api_path(self, path: str, method: str = "GET") -> bool:
        """Return True for ``POST /api/bots/<bot_id>/enable`` (exact shape).

        The path is matched structurally — exactly one non-empty bot-id
        segment followed by ``enable``, with no trailing slash — rather than by
        rebuilding it from the decoded id. That keeps ids containing
        percent-escapes (e.g. ``my%20bot``) recognised, so a disabled bot with
        such an id can still be re-enabled.
        """
        if str(method or "GET").strip().upper() != "POST":
            return False
        raw = str(path or "").strip()
        if not self.extract_bot_id_from_api_path(raw):
            return False
        segments = raw[len(self.BOT_API_PREFIX):].split("/")
        return len(segments) == 2 and segments[1] == "enable"

    def validate_dashboard_auth(self, request: Request, session: Session) -> Optional[str]:
        """Authenticate *request*; return an error message, or ``None`` on success.

        On success the resolved user is recorded on ``request.state`` as
        ``sys_auth_mode`` (``"session_token"``), ``sys_user_id`` and
        ``sys_username``.
        """
        token = self.get_supplied_auth_token_http(request)
        if not token:
            return "Authentication required"

        # Imported lazily, presumably to avoid an import cycle — confirm.
        from services.sys_auth_service import resolve_user_by_token

        user = resolve_user_by_token(session, token)
        if user is None:
            return "Session expired or invalid"

        request.state.sys_auth_mode = "session_token"
        request.state.sys_user_id = int(user.id or 0)
        request.state.sys_username = str(user.username or "")
        return None

    @staticmethod
    def _json_error(request: Request, *, status_code: int, detail: str) -> JSONResponse:
        """Build a ``{"detail": ...}`` JSON error response with CORS headers."""
        headers = {"Access-Control-Allow-Origin": "*"}
        origin = str(request.headers.get("origin") or "").strip()
        if origin:
            # NOTE(review): with a wildcard Allow-Origin the response does not
            # actually vary by Origin; confirm whether the origin was meant to be
            # echoed back instead of "*".
            headers["Vary"] = "Origin"
        return JSONResponse(status_code=status_code, content={"detail": detail}, headers=headers)

    async def guard(self, request: Request, call_next: Callable[..., Any]):
        """Reject unauthenticated or unauthorized requests; forward the rest.

        Returns the downstream response from ``call_next`` when the request is
        allowed, otherwise a JSON error response (401/403/404).
        """
        if request.method.upper() == "OPTIONS":
            return await call_next(request)

        if self.is_public_api_path(request.url.path, request.method):
            return await call_next(request)

        # All DB-backed checks share one short-lived session, which is closed
        # before the request is handed on so no connection is held while the
        # downstream handler runs.
        with Session(self._engine) as session:
            rejection = self._authorize_request(request, session)
        if rejection is not None:
            return rejection
        return await call_next(request)

    def _authorize_request(self, request: Request, session: Session) -> Optional[JSONResponse]:
        """Return the error response that should reject *request*, or ``None`` to allow it."""
        auth_error = self.validate_dashboard_auth(request, session)
        if auth_error:
            return self._json_error(request, status_code=401, detail=auth_error)

        bot_id = self.extract_bot_id_from_api_path(request.url.path)
        if not bot_id:
            # Authenticated, non-bot API call: nothing further to check.
            return None

        # Imported lazily, presumably to avoid an import cycle — confirm.
        from models.sys_auth import SysUser
        from services.sys_auth_service import user_can_access_bot

        current_user_id = int(getattr(request.state, "sys_user_id", 0) or 0)
        current_user = session.get(SysUser, current_user_id) if current_user_id > 0 else None
        if current_user is None:
            return self._json_error(request, status_code=401, detail="Authentication required")

        if not user_can_access_bot(session, current_user, bot_id):
            return self._json_error(request, status_code=403, detail="You do not have access to this bot")

        bot = session.get(BotInstance, bot_id)
        if not bot:
            return self._json_error(request, status_code=404, detail="Bot not found")

        if not bool(getattr(bot, "enabled", True)):
            # A disabled bot still serves reads and the call that re-enables it.
            is_enable_api = self.is_bot_enable_api_path(request.url.path, request.method)
            is_read_api = request.method.upper() == "GET"
            if not (is_enable_api or is_read_api):
                return self._json_error(request, status_code=403, detail="Bot is disabled. Enable it first.")

        return None