dashboard-nanobot/backend/services/dashboard_auth_service.py

130 lines
5.1 KiB
Python

from typing import Any, Callable, Optional
from urllib.parse import unquote
from fastapi import Request
from fastapi.responses import JSONResponse
from sqlmodel import Session
from models.bot import BotInstance
class DashboardAuthService:
    """Request guard for the dashboard HTTP API.

    ``guard`` lets CORS preflight (``OPTIONS``) requests and a fixed set of
    public paths through unconditionally. Every other request must carry a
    valid session token. Requests under ``/api/bots/{bot_id}`` are further
    checked for per-user bot access and bot existence. While a bot is disabled,
    only read-only requests and the enable endpoint may reach it.
    """

    AUTH_TOKEN_HEADER = "authorization"
    AUTH_TOKEN_FALLBACK_HEADER = "x-auth-token"

    # Prefix whose first following path segment is treated as a bot id.
    _BOTS_PREFIX = "/api/bots/"
    # Exact paths that are reachable without authentication.
    _PUBLIC_API_PATHS = frozenset(
        {
            "/api/sys/auth/status",
            "/api/sys/auth/login",
            "/api/sys/auth/logout",
            "/api/health",
            "/api/health/cache",
        }
    )
    # Methods still allowed against a disabled bot. HEAD is included because
    # Starlette serves HEAD with the matching GET handler, so it is just as
    # read-only as GET.
    _READ_ONLY_METHODS = frozenset({"GET", "HEAD"})

    def __init__(self, *, engine: Any) -> None:
        """Keep the engine used to open a ``Session`` for each guarded request."""
        self._engine = engine

    def extract_bot_id_from_api_path(self, path: str) -> Optional[str]:
        """Return the bot id from a ``/api/bots/{bot_id}[/...]`` path.

        The first segment after the prefix is stripped and percent-decoded.
        Returns ``None`` when the path is outside the prefix or the segment is
        empty.
        """
        raw = str(path or "").strip()
        if not raw.startswith(self._BOTS_PREFIX):
            return None
        segment = raw[len(self._BOTS_PREFIX):].split("/", 1)[0].strip()
        if not segment:
            return None
        # unquote() on a str does not raise (errors="replace" by default).
        # NOTE(review): Starlette's request.url.path is normally already
        # percent-decoded, so this may decode a second time for ids that
        # contain '%' -- confirm bot ids can never contain '%'.
        return unquote(segment).strip() or None

    def get_supplied_auth_token_http(self, request: Request) -> str:
        """Return the auth token supplied with ``request``, or ``""`` if none.

        Sources are checked in order: ``Authorization: Bearer <token>``
        (scheme matched case-insensitively), the ``X-Auth-Token`` header, and
        finally the ``auth_token`` query parameter.
        """
        auth_header = str(request.headers.get(self.AUTH_TOKEN_HEADER) or "").strip()
        bearer_prefix = "bearer "
        if auth_header.lower().startswith(bearer_prefix):
            token = auth_header[len(bearer_prefix):].strip()
            if token:
                return token
        header_value = str(request.headers.get(self.AUTH_TOKEN_FALLBACK_HEADER) or "").strip()
        if header_value:
            return header_value
        # NOTE(review): tokens passed in the query string can end up in access
        # logs; presumably kept for clients that cannot set headers -- confirm.
        return str(request.query_params.get("auth_token") or "").strip()

    @staticmethod
    def is_public_api_path(path: str, method: str = "GET") -> bool:
        """Return True if ``path`` is one of the unauthenticated API endpoints.

        ``method`` is accepted for interface compatibility and is not used.
        """
        return str(path or "").strip() in DashboardAuthService._PUBLIC_API_PATHS

    def is_bot_enable_api_path(self, path: str, method: str = "GET") -> bool:
        """Return True for ``POST /api/bots/{bot_id}/enable``.

        Matching is done on path segments rather than by rebuilding the path
        from the decoded bot id. A path whose id segment contains
        percent-escapes is therefore still recognised.
        """
        if str(method or "GET").strip().upper() != "POST":
            return False
        raw = str(path or "").strip()
        if not self.extract_bot_id_from_api_path(raw):
            return False
        segments = raw[len(self._BOTS_PREFIX):].split("/")
        return len(segments) == 2 and segments[1] == "enable"

    def validate_dashboard_auth(self, request: Request, session: Session) -> Optional[str]:
        """Authenticate ``request`` using ``session``.

        On success, records ``sys_auth_mode``, ``sys_user_id`` and
        ``sys_username`` on ``request.state`` and returns ``None``. Otherwise
        returns an error message suitable for a 401 response.
        """
        token = self.get_supplied_auth_token_http(request)
        if not token:
            return "Authentication required"
        from services.sys_auth_service import resolve_user_by_token

        user = resolve_user_by_token(session, token)
        if user is None:
            return "Session expired or invalid"
        request.state.sys_auth_mode = "session_token"
        request.state.sys_user_id = int(user.id or 0)
        request.state.sys_username = str(user.username or "")
        return None

    @staticmethod
    def _json_error(request: Request, *, status_code: int, detail: str) -> JSONResponse:
        """Build a JSON ``{"detail": ...}`` error response with CORS headers.

        ``Access-Control-Allow-Origin`` is always ``*``. ``Vary: Origin`` is
        added when the request carries an ``Origin`` header.
        """
        headers = {"Access-Control-Allow-Origin": "*"}
        origin = str(request.headers.get("origin") or "").strip()
        if origin:
            headers["Vary"] = "Origin"
        return JSONResponse(status_code=status_code, content={"detail": detail}, headers=headers)

    def _authorize_request(self, request: Request) -> Optional[JSONResponse]:
        """Run auth and per-bot checks; return an error response or ``None``.

        All database work happens in one session that is closed before this
        method returns. No connection is held while the downstream handler runs.
        """
        path = request.url.path
        with Session(self._engine) as session:
            auth_error = self.validate_dashboard_auth(request, session)
            if auth_error:
                return self._json_error(request, status_code=401, detail=auth_error)

            bot_id = self.extract_bot_id_from_api_path(path)
            if not bot_id:
                return None

            # Imported lazily as in the original code (presumably to avoid an
            # import cycle between models/services -- confirm).
            from models.sys_auth import SysUser
            from services.sys_auth_service import user_can_access_bot

            current_user_id = int(getattr(request.state, "sys_user_id", 0) or 0)
            current_user = session.get(SysUser, current_user_id) if current_user_id > 0 else None
            if current_user is None:
                return self._json_error(request, status_code=401, detail="Authentication required")
            if not user_can_access_bot(session, current_user, bot_id):
                return self._json_error(request, status_code=403, detail="You do not have access to this bot")

            bot = session.get(BotInstance, bot_id)
            if not bot:
                return self._json_error(request, status_code=404, detail="Bot not found")

            if not bool(getattr(bot, "enabled", True)):
                is_enable_api = self.is_bot_enable_api_path(path, request.method)
                is_read_api = request.method.upper() in self._READ_ONLY_METHODS
                if not (is_enable_api or is_read_api):
                    return self._json_error(request, status_code=403, detail="Bot is disabled. Enable it first.")
        return None

    async def guard(self, request: Request, call_next: Callable[..., Any]):
        """HTTP middleware entry point: authorize ``request`` then forward it.

        CORS preflight (``OPTIONS``) requests and public API paths bypass all
        checks. Otherwise, the first failing check short-circuits with its
        JSON error response.
        """
        if request.method.upper() == "OPTIONS":
            return await call_next(request)
        if self.is_public_api_path(request.url.path, request.method):
            return await call_next(request)
        denial = self._authorize_request(request)
        if denial is not None:
            return denial
        return await call_next(request)