126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
from fastapi import Request
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from core.settings import PANEL_ACCESS_PASSWORD
|
|
from services.bot_storage_service import _read_bot_config
|
|
|
|
# HTTP header names used to pass credentials. Header names are
# case-insensitive on the wire, so the differing capitalisation is cosmetic.
PANEL_ACCESS_PASSWORD_HEADER = "x-panel-password"  # panel (admin) password
BOT_ACCESS_PASSWORD_HEADER = "X-Bot-Access-Password"  # per-bot password

# Bot action path suffixes that, per the name, should only be authorised by the
# panel password (never by a bot's own access password).
BOT_PANEL_ONLY_SUFFIXES = {"/deactivate", "/disable", "/enable"}
|
|
|
|
|
|
def _extract_bot_id_from_api_path(path: str) -> Optional[str]:
|
|
parts = [p for p in path.split("/") if p.strip()]
|
|
if len(parts) >= 3 and parts[0] == "api" and parts[1] == "bots":
|
|
return parts[2]
|
|
return None
|
|
|
|
|
|
def _get_supplied_panel_password_http(request: Request) -> str:
    """Return the panel access password supplied with an HTTP request.

    The ``PANEL_ACCESS_PASSWORD_HEADER`` header takes precedence. When it is
    absent or blank, the ``panel_access_password`` query parameter is used.
    Both sources are whitespace-stripped, and ``""`` means nothing was supplied.
    """
    # NOTE(review): query-string credentials can end up in server/proxy access
    # logs; the header is the preferable channel for clients.
    from_header = str(request.headers.get(PANEL_ACCESS_PASSWORD_HEADER) or "").strip()
    return from_header or str(request.query_params.get("panel_access_password") or "").strip()
|
|
|
|
|
|
def _get_supplied_bot_access_password_http(request: Request) -> str:
    """Return the bot access password supplied with an HTTP request.

    Reads the ``BOT_ACCESS_PASSWORD_HEADER`` header first. If that is missing
    or blank, falls back to the ``bot_access_password`` query parameter.
    The result is whitespace-stripped, and ``""`` means no password was given.
    """
    header_password = str(request.headers.get(BOT_ACCESS_PASSWORD_HEADER) or "").strip()
    if not header_password:
        # Query-string fallback, presumably for clients that cannot set headers.
        return str(request.query_params.get("bot_access_password") or "").strip()
    return header_password
|
|
|
|
|
|
def _validate_panel_access_password(supplied: str) -> Optional[str]:
    """Check *supplied* against the configured panel access password.

    Args:
        supplied: Password provided by the client. Surrounding whitespace is
            ignored and ``None`` is treated as empty.

    Returns:
        ``None`` when access is allowed, i.e. no panel password is configured
        or the password matches. Otherwise returns an error message suitable
        for the response ``detail``.
    """
    import hmac  # local import keeps the module's import block unchanged

    configured = str(PANEL_ACCESS_PASSWORD or "").strip()
    if not configured:
        # No panel password configured: the panel is open.
        return None

    candidate = str(supplied or "").strip()
    if not candidate:
        return "Panel access password required"

    # Constant-time comparison, so response timing does not reveal how much of
    # a guess matched. Encode first because compare_digest only accepts
    # ASCII-only str values and would raise TypeError on other text.
    if not hmac.compare_digest(candidate.encode("utf-8"), configured.encode("utf-8")):
        return "Invalid panel access password"

    return None
|
|
|
|
|
|
def _validate_bot_access_password(bot_id: str, supplied: str) -> Optional[str]:
    """Check *supplied* against the ``access_password`` stored for *bot_id*.

    Args:
        bot_id: Identifier passed to ``_read_bot_config``.
        supplied: Password provided by the client. Surrounding whitespace is
            ignored and ``None`` is treated as empty.

    Returns:
        ``None`` when access is allowed, i.e. the bot has no (non-blank)
        ``access_password`` or the password matches. Otherwise returns an
        error message suitable for the response ``detail``.
    """
    import hmac  # local import keeps the module's import block unchanged

    # Assumes _read_bot_config returns a mapping; any exception it raises
    # (e.g. for an unknown bot) propagates to the caller.
    config = _read_bot_config(bot_id)
    configured = str(config.get("access_password") or "").strip()
    if not configured:
        return None

    candidate = str(supplied or "").strip()
    if not candidate:
        return "Bot access password required"

    # Constant-time comparison on UTF-8 bytes; see hmac.compare_digest docs.
    if not hmac.compare_digest(candidate.encode("utf-8"), configured.encode("utf-8")):
        return "Invalid bot access password"

    return None
|
|
|
|
|
|
def _is_bot_panel_management_api_path(path: str, method: str = "GET") -> bool:
    """Return True when the request is a bot management action.

    A management action is a path under ``/api/bots/<bot_id>`` that ends in
    ``/start``, ``/stop``, ``/enable``, ``/disable`` or ``/deactivate``. It is
    also ``PUT``/``DELETE`` on exactly ``/api/bots/<bot_id>``.

    *path* is whitespace-stripped. *method* is stripped and upper-cased, and a
    ``None`` or empty value counts as ``GET``.
    """
    normalized_path = str(path or "").strip()
    http_verb = str(method or "GET").strip().upper()

    if not normalized_path.startswith("/api/bots/"):
        return False

    bot_id = _extract_bot_id_from_api_path(normalized_path)
    if not bot_id:
        return False

    action_suffixes = ("/start", "/stop", "/enable", "/disable", "/deactivate")
    if normalized_path.endswith(action_suffixes):
        return True

    # Updating or deleting the bot resource itself.
    return http_verb in {"PUT", "DELETE"} and normalized_path == f"/api/bots/{bot_id}"
|
|
|
|
|
|
def _is_panel_protected_api_path(path: str, method: str = "GET") -> bool:
    """Decide whether an API request must present the panel access password.

    Returns False for paths outside ``/api/``, for the panel auth and health
    endpoints, and for per-bot paths that are not management actions.

    Returns True for bot management actions (as classified by
    ``_is_bot_panel_management_api_path``) and for every other ``/api/`` path.
    """
    candidate = str(path or "").strip()
    http_verb = str(method or "GET").strip().upper()

    # Endpoints reachable without the panel password.
    public_endpoints = {
        "/api/panel/auth/status",
        "/api/panel/auth/login",
        "/api/health",
        "/api/health/cache",
    }
    if not candidate.startswith("/api/") or candidate in public_endpoints:
        return False

    if _is_bot_panel_management_api_path(candidate, http_verb):
        return True

    # Remaining per-bot paths are not gated by the panel password here.
    return not _extract_bot_id_from_api_path(candidate)
|
|
|
|
|
|
class PasswordProtectionMiddleware(BaseHTTPMiddleware):
    """Enforce the panel and bot access passwords on HTTP API requests.

    Rules applied by :meth:`dispatch`:

    * ``OPTIONS`` requests pass through unchecked.
    * Paths without a bot id that ``_is_panel_protected_api_path`` flags
      require the panel access password.
    * Bot management paths (see ``_is_bot_panel_management_api_path``) are
      allowed with a valid panel password. Without one, a valid bot access
      password is accepted instead, but only when the bot actually has a
      password configured and the action is not in ``BOT_PANEL_ONLY_SUFFIXES``.
    * Other per-bot paths are not checked by this middleware.
      NOTE(review): presumably enforced by the bot routes themselves; verify.

    Rejections are ``401`` JSON responses of the form ``{"detail": <message>}``.
    """

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 response used for every rejected request."""
        return JSONResponse(status_code=401, content={"detail": detail})

    @staticmethod
    def _is_panel_only_bot_action(path: str) -> bool:
        """Return True when *path* ends with a suffix from ``BOT_PANEL_ONLY_SUFFIXES``."""
        raw = str(path or "").strip()
        return any(raw.endswith(suffix) for suffix in BOT_PANEL_ONLY_SUFFIXES)

    @staticmethod
    def _bot_has_access_password(bot_id: str) -> bool:
        """Return True when the bot's config defines a non-blank ``access_password``."""
        config = _read_bot_config(bot_id)
        return bool(str(config.get("access_password") or "").strip())

    async def dispatch(self, request: Request, call_next):
        """Authorize the request per the class rules, then forward it or return 401."""
        path = request.url.path
        method = request.method.upper()

        # Let OPTIONS (e.g. CORS preflight) through without credentials.
        if method == "OPTIONS":
            return await call_next(request)

        bot_id = _extract_bot_id_from_api_path(path)
        if not bot_id:
            if _is_panel_protected_api_path(path, method):
                panel_error = _validate_panel_access_password(_get_supplied_panel_password_http(request))
                if panel_error:
                    return self._unauthorized(panel_error)
            return await call_next(request)

        if _is_bot_panel_management_api_path(path, method):
            panel_error = _validate_panel_access_password(_get_supplied_panel_password_http(request))
            if panel_error:
                # The bot password may stand in for the panel password only when
                # the bot really has one and the action is not panel-only.
                # Otherwise an unprotected bot would let anyone manage it (its
                # validator returns None when no password is configured).
                if self._is_panel_only_bot_action(path) or not self._bot_has_access_password(bot_id):
                    return self._unauthorized(panel_error)
                bot_error = _validate_bot_access_password(bot_id, _get_supplied_bot_access_password_http(request))
                if bot_error:
                    return self._unauthorized(bot_error)

        return await call_next(request)
|