"""Panel and per-bot session-token authentication.

Tokens are random strings handed to the client (cookie or bearer header).
Only their SHA-256 hash is stored: the cache (``auth_cache``) holds the active
token payloads plus a per-principal set of token hashes, and every issued
token also gets an ``AuthLoginLog`` audit row in the database.
"""

from __future__ import annotations

import hashlib
import re
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional

from fastapi import Request, Response, WebSocket
from sqlmodel import Session, select

from core.cache import auth_cache
from core.settings import PANEL_ACCESS_PASSWORD
from models.auth import AuthLoginLog
from models.bot import BotInstance
from services.platform_settings_service import get_auth_token_max_active, get_auth_token_ttl_hours

PANEL_TOKEN_COOKIE = "nanobot_panel_token"
BOT_TOKEN_COOKIE_PREFIX = "nanobot_bot_token_"
PANEL_SUBJECT_ID = "panel_admin"
# Extra lifetime (seconds) given to a principal's token-hash set beyond the token TTL.
AUTH_STORE_SET_TTL_BUFFER_SECONDS = 300
# Minimum interval (seconds) between writes of AuthLoginLog.last_seen_at for a session.
SESSION_TOUCH_INTERVAL_SECONDS = 300


@dataclass(frozen=True)
class AuthPrincipal:
    """Result of resolving the credentials attached to a request or websocket."""

    auth_type: str  # "panel" or "bot" in this module
    subject_id: str  # empty string when not authenticated
    bot_id: Optional[str]
    authenticated: bool
    auth_source: str  # e.g. "missing", "unprotected", "panel_token", "bot_token"
    audit_id: Optional[int] = None  # AuthLoginLog.id backing the token, when known


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Same value as the deprecated ``datetime.utcnow()``. Naive values are kept
    because the timestamps this module writes (created_at, expires_at,
    last_seen_at, revoked_at) are all produced here as naive UTC.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _to_epoch_seconds(value: datetime) -> int:
    """Return *value* as integer Unix seconds, treating naive datetimes as UTC.

    ``datetime.timestamp()`` interprets naive values as local time, which would
    skew the result by the server's UTC offset.
    """
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return int(value.timestamp())


def _parse_payload_datetime(raw: Any) -> Optional[datetime]:
    """Parse an ISO-8601 timestamp stored in a token payload into naive UTC.

    Returns None when the value is missing or cannot be parsed.
    """
    text = str(raw or "").strip()
    if not text:
        return None
    try:
        # Payloads are written as ``<naive isoformat> + "Z"``; drop the suffix.
        parsed = datetime.fromisoformat(text.replace("Z", ""))
    except ValueError:
        return None
    if parsed.tzinfo is not None:
        # Normalise offset-aware values so they compare safely with _utcnow().
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


def _normalize_token(raw: str) -> str:
    """Coerce a raw token value to a stripped string ("" for None)."""
    return str(raw or "").strip()


def _hash_session_token(raw: str) -> str:
    """Return the hex SHA-256 digest of the normalised token."""
    return hashlib.sha256(_normalize_token(raw).encode("utf-8")).hexdigest()


def _normalize_bot_cookie_name(bot_id: str) -> str:
    """Build the per-bot cookie name, replacing unsafe characters with '_'."""
    safe_bot_id = re.sub(r"[^a-zA-Z0-9_-]+", "_", str(bot_id or "").strip())
    return f"{BOT_TOKEN_COOKIE_PREFIX}{safe_bot_id or 'bot'}"


def _token_key(token_hash: str) -> str:
    """Cache key holding the JSON payload for one token hash."""
    return f"token:{str(token_hash or '').strip()}"


def _principal_tokens_key(auth_type: str, subject_id: str, bot_id: Optional[str] = None) -> str:
    """Cache key of the set of active token hashes for one principal."""
    normalized_type = str(auth_type or "").strip().lower() or "unknown"
    normalized_subject = re.sub(r"[^a-zA-Z0-9_.:-]+", "_", str(subject_id or "").strip() or "anonymous")
    normalized_bot_id = re.sub(r"[^a-zA-Z0-9_.:-]+", "_", str(bot_id or "").strip()) if bot_id else ""
    return f"principal:{normalized_type}:{normalized_subject}:{normalized_bot_id or '-'}"


def _auth_token_ttl_seconds(session: Session) -> int:
    """Configured token lifetime in seconds (at least one hour)."""
    return max(1, int(get_auth_token_ttl_hours(session))) * 60 * 60


def _auth_token_max_active(session: Session) -> int:
    """Configured maximum number of concurrent tokens per principal (at least 1)."""
    return max(1, int(get_auth_token_max_active(session)))


def _touch_session(session: Session, row: AuthLoginLog) -> None:
    """Refresh ``row.last_seen_at``, at most once per SESSION_TOUCH_INTERVAL_SECONDS.

    The throttle avoids a database commit on every authenticated request.
    """
    now = _utcnow()
    last_seen = row.last_seen_at or row.created_at or now
    if (now - last_seen).total_seconds() < SESSION_TOUCH_INTERVAL_SECONDS:
        return
    row.last_seen_at = now
    session.add(row)
    session.commit()


def _summarize_device(user_agent: str) -> str:
    """Return a coarse "<platform> / <browser>" label derived from a User-Agent."""
    normalized = str(user_agent or "").strip().lower()
    if not normalized:
        return "Unknown Device"

    # Order matters: Edge UAs also contain "chrome/", Chrome UAs contain "safari/".
    browser = "Unknown Browser"
    if "edg/" in normalized:
        browser = "Edge"
    elif "chrome/" in normalized:
        browser = "Chrome"
    elif "safari/" in normalized:
        browser = "Safari"
    elif "firefox/" in normalized:
        browser = "Firefox"

    # Order matters: Android UAs also contain "linux".
    platform = "Desktop"
    if "iphone" in normalized:
        platform = "iPhone"
    elif "ipad" in normalized:
        platform = "iPad"
    elif "android" in normalized:
        platform = "Android"
    elif "mac os x" in normalized or "macintosh" in normalized:
        platform = "macOS"
    elif "windows" in normalized:
        platform = "Windows"
    elif "linux" in normalized:
        platform = "Linux"
    return f"{platform} / {browser}"


def _extract_client_ip(request: Request) -> str:
    """Return the client IP for audit logging, truncated to 120 characters.

    Prefers the first X-Forwarded-For entry and falls back to the socket peer.
    NOTE(review): X-Forwarded-For is client-controllable unless a trusted proxy
    overwrites it, so treat this value as informational only.
    """
    forwarded = str(request.headers.get("x-forwarded-for") or "").strip()
    if forwarded:
        return forwarded.split(",")[0].strip()[:120]
    return str(getattr(request.client, "host", "") or "")[:120]


def _get_bearer_token(headers: Mapping[str, Any]) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header, or ""."""
    authorization = str(headers.get("authorization") or headers.get("Authorization") or "").strip()
    if not authorization.lower().startswith("bearer "):
        return ""
    return _normalize_token(authorization[len("bearer "):])


def _read_panel_token(request: Request) -> str:
    """Panel token from its cookie, falling back to the bearer header."""
    cookie_token = _normalize_token(request.cookies.get(PANEL_TOKEN_COOKIE) or "")
    return cookie_token or _get_bearer_token(request.headers)


def _read_bot_token(request: Request, bot_id: str) -> str:
    """Bot token from the per-bot cookie, falling back to the bearer header."""
    cookie_token = _normalize_token(request.cookies.get(_normalize_bot_cookie_name(bot_id)) or "")
    return cookie_token or _get_bearer_token(request.headers)


def _read_panel_token_ws(websocket: WebSocket) -> str:
    """WebSocket variant of :func:`_read_panel_token`."""
    cookie_token = _normalize_token(websocket.cookies.get(PANEL_TOKEN_COOKIE) or "")
    return cookie_token or _get_bearer_token(websocket.headers)


def _read_bot_token_ws(websocket: WebSocket, bot_id: str) -> str:
    """WebSocket variant of :func:`_read_bot_token`."""
    cookie_token = _normalize_token(websocket.cookies.get(_normalize_bot_cookie_name(bot_id)) or "")
    return cookie_token or _get_bearer_token(websocket.headers)


def _is_panel_auth_enabled() -> bool:
    """Panel auth is enforced only when PANEL_ACCESS_PASSWORD is non-blank."""
    return bool(str(PANEL_ACCESS_PASSWORD or "").strip())


def _get_bot_access_password(session: Session, bot_id: str) -> str:
    """Stripped access password of the bot, or "" if the bot does not exist."""
    bot = session.get(BotInstance, bot_id)
    if not bot:
        return ""
    return str(bot.access_password or "").strip()


def _is_bot_access_enabled(session: Session, bot_id: str) -> bool:
    """True when the bot has a non-blank access password."""
    return bool(_get_bot_access_password(session, bot_id))


def _resolve_bot_auth_source(session: Session, bot_id: str) -> str:
    """Label recorded on bot tokens: password-protected vs. public bot."""
    return "bot_password" if _is_bot_access_enabled(session, bot_id) else "bot_public"


def _active_token_payload(token_hash: str) -> Optional[dict[str, Any]]:
    """Cached payload for *token_hash*, or None if absent or not a dict."""
    payload = auth_cache.get_json(_token_key(token_hash))
    return payload if isinstance(payload, dict) else None


def _principal_from_payload(payload: dict[str, Any]) -> tuple[str, str, Optional[str]]:
    """Return ``(auth_type, subject_id, bot_id)`` normalised from a token payload."""
    auth_type = str(payload.get("auth_type") or "").strip().lower()
    subject_id = str(payload.get("subject_id") or "").strip()
    bot_id = str(payload.get("bot_id") or "").strip() or None
    return auth_type, subject_id, bot_id


def _mark_audit_revoked(session: Session, token_hash: str, *, reason: str) -> None:
    """Stamp the audit row for *token_hash* as revoked (first revocation wins)."""
    row = session.exec(
        select(AuthLoginLog).where(AuthLoginLog.token_hash == token_hash).limit(1)
    ).first()
    if not row:
        return
    now = _utcnow()
    row.last_seen_at = now
    if row.revoked_at is None:
        row.revoked_at = now
        row.revoke_reason = str(reason or "").strip()[:120] or row.revoke_reason
    session.add(row)
    session.commit()


def _revoke_token_hash(session: Session, token_hash: str, *, reason: str) -> None:
    """Remove a token from the cache and mark its audit row revoked."""
    normalized_hash = str(token_hash or "").strip()
    if not normalized_hash:
        return
    payload = _active_token_payload(normalized_hash)
    if payload:
        # The principal set can only be cleaned up while the payload still exists.
        auth_type, subject_id, bot_id = _principal_from_payload(payload)
        auth_cache.delete(_token_key(normalized_hash))
        auth_cache.srem(_principal_tokens_key(auth_type, subject_id, bot_id), normalized_hash)
    _mark_audit_revoked(session, normalized_hash, reason=reason)


def _revoke_raw_token(session: Session, raw_token: str, *, reason: str) -> None:
    """Revoke a token given in raw (unhashed) form; no-op for a blank token."""
    token = _normalize_token(raw_token)
    if not token:
        return
    _revoke_token_hash(session, _hash_session_token(token), reason=reason)


def _cleanup_principal_set(session: Session, principal_key: str) -> list[tuple[int, str]]:
    """Drop stale hashes from a principal set and list the live tokens.

    Hashes whose payload has disappeared from the cache are removed from the set
    and their audit rows marked revoked with reason "expired".

    Returns:
        ``(issued_at_ts, token_hash)`` pairs sorted oldest first.
    """
    active_rows: list[tuple[int, str]] = []
    stale_hashes: list[str] = []
    for token_hash in auth_cache.smembers(principal_key):
        payload = _active_token_payload(token_hash)
        if not payload:
            stale_hashes.append(token_hash)
            continue
        issued_at_ts = int(payload.get("issued_at_ts") or 0)
        active_rows.append((issued_at_ts, token_hash))
    if stale_hashes:
        auth_cache.srem(principal_key, *stale_hashes)
        for stale_hash in stale_hashes:
            _mark_audit_revoked(session, stale_hash, reason="expired")
    return sorted(active_rows, key=lambda row: (row[0], row[1]))


def _ensure_auth_store_available() -> None:
    """Raise RuntimeError when the cache backend used for tokens is disabled."""
    if auth_cache.enabled:
        return
    raise RuntimeError("Redis authentication store is unavailable")


def _persist_token_payload(
    session: Session,
    *,
    row: AuthLoginLog,
    raw_token: str,
    ttl_seconds: int,
) -> None:
    """Write the token payload and principal-set entry to the cache.

    The write is read back; if the payload is not present, the audit row is
    marked revoked ("store_write_failed") and RuntimeError is raised.
    """
    token_hash = _hash_session_token(raw_token)
    payload = {
        "auth_type": row.auth_type,
        "subject_id": row.subject_id,
        "bot_id": row.bot_id,
        "auth_source": row.auth_source,
        "issued_at": row.created_at.isoformat() + "Z",
        # Naive UTC must not go through .timestamp() directly (it assumes local time).
        "issued_at_ts": _to_epoch_seconds(row.created_at),
        "expires_at": row.expires_at.isoformat() + "Z",
        "audit_id": int(row.id or 0),
    }
    principal_key = _principal_tokens_key(row.auth_type, row.subject_id, row.bot_id)
    auth_cache.set_json(_token_key(token_hash), payload, ttl=ttl_seconds)
    auth_cache.sadd(principal_key, token_hash)
    auth_cache.expire(principal_key, ttl_seconds + AUTH_STORE_SET_TTL_BUFFER_SECONDS)
    if not _active_token_payload(token_hash):
        row.revoked_at = _utcnow()
        row.revoke_reason = "store_write_failed"
        session.add(row)
        session.commit()
        raise RuntimeError("Failed to persist authentication token")


def _enforce_token_limit(
    session: Session,
    *,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    max_active: int,
) -> None:
    """Revoke the oldest tokens so one more can be issued within *max_active*."""
    principal_key = _principal_tokens_key(auth_type, subject_id, bot_id)
    rows = _cleanup_principal_set(session, principal_key)
    # "+ 1" reserves a slot for the token about to be created.
    overflow = max(0, len(rows) - max_active + 1)
    if overflow <= 0:
        return
    for _, token_hash in rows[:overflow]:
        _revoke_token_hash(session, token_hash, reason="concurrency_limit")


def _create_audit_row(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    raw_token: str,
    expires_at: datetime,
    auth_source: str,
) -> AuthLoginLog:
    """Insert and return (refreshed) the AuthLoginLog row for a new token."""
    now = _utcnow()
    user_agent = str(request.headers.get("user-agent") or "")
    row = AuthLoginLog(
        auth_type=auth_type,
        token_hash=_hash_session_token(raw_token),
        subject_id=subject_id,
        bot_id=bot_id,
        auth_source=auth_source,
        created_at=now,
        expires_at=expires_at,
        last_seen_at=now,
        client_ip=_extract_client_ip(request),
        user_agent=user_agent[:500],
        device_info=_summarize_device(user_agent)[:255],
    )
    session.add(row)
    session.commit()
    session.refresh(row)
    return row


def _create_auth_token(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    auth_source: str,
) -> str:
    """Issue a new raw token: enforce the limit, audit it, then store it.

    Raises:
        RuntimeError: if the cache store is unavailable or the write fails.
    """
    _ensure_auth_store_available()
    ttl_seconds = _auth_token_ttl_seconds(session)
    max_active = _auth_token_max_active(session)
    _enforce_token_limit(
        session,
        auth_type=auth_type,
        subject_id=subject_id,
        bot_id=bot_id,
        max_active=max_active,
    )
    raw_token = secrets.token_urlsafe(32)
    row = _create_audit_row(
        session,
        request=request,
        auth_type=auth_type,
        subject_id=subject_id,
        bot_id=bot_id,
        raw_token=raw_token,
        expires_at=_utcnow() + timedelta(seconds=ttl_seconds),
        auth_source=auth_source,
    )
    _persist_token_payload(session, row=row, raw_token=raw_token, ttl_seconds=ttl_seconds)
    return raw_token


def create_panel_token(session: Session, request: Request) -> str:
    """Replace any panel token presented by *request* with a freshly issued one."""
    revoke_panel_token(session, request, reason="superseded")
    return _create_auth_token(
        session,
        request=request,
        auth_type="panel",
        subject_id=PANEL_SUBJECT_ID,
        bot_id=None,
        auth_source="panel_password",
    )


def create_bot_token(session: Session, request: Request, bot_id: str) -> str:
    """Replace any token for *bot_id* presented by *request* with a new one."""
    normalized_bot_id = str(bot_id or "").strip()
    revoke_bot_token(session, request, normalized_bot_id, reason="superseded")
    return _create_auth_token(
        session,
        request=request,
        auth_type="bot",
        subject_id=normalized_bot_id,
        bot_id=normalized_bot_id,
        auth_source=_resolve_bot_auth_source(session, normalized_bot_id),
    )


def revoke_panel_token(session: Session, request: Request, reason: str = "logout") -> None:
    """Revoke the panel token carried by *request*, if any."""
    _revoke_raw_token(session, _read_panel_token(request), reason=reason)


def revoke_bot_token(session: Session, request: Request, bot_id: str, reason: str = "logout") -> None:
    """Revoke the token for *bot_id* carried by *request*, if any."""
    _revoke_raw_token(session, _read_bot_token(request, bot_id), reason=reason)


def _set_cookie(response: Response, request: Request, name: str, raw_token: str, max_age: int) -> None:
    """Set an HttpOnly, SameSite=Lax token cookie; Secure when served over https."""
    response.set_cookie(
        name,
        raw_token,
        max_age=max_age,
        httponly=True,
        samesite="lax",
        secure=str(request.url.scheme).lower() == "https",
        path="/",
    )


def set_panel_token_cookie(response: Response, request: Request, raw_token: str, session: Session) -> None:
    """Store the panel token in its cookie with the configured TTL."""
    _set_cookie(response, request, PANEL_TOKEN_COOKIE, raw_token, _auth_token_ttl_seconds(session))


def set_bot_token_cookie(response: Response, request: Request, bot_id: str, raw_token: str, session: Session) -> None:
    """Store a bot token in its per-bot cookie with the configured TTL."""
    _set_cookie(response, request, _normalize_bot_cookie_name(bot_id), raw_token, _auth_token_ttl_seconds(session))


def clear_panel_token_cookie(response: Response) -> None:
    """Delete the panel token cookie."""
    response.delete_cookie(PANEL_TOKEN_COOKIE, path="/")


def clear_bot_token_cookie(response: Response, bot_id: str) -> None:
    """Delete the per-bot token cookie."""
    response.delete_cookie(_normalize_bot_cookie_name(bot_id), path="/")


def _resolve_token_auth(
    session: Session,
    *,
    raw_token: str,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> AuthPrincipal:
    """Validate *raw_token* against the cache and return the resulting principal.

    Any failure (missing/unknown token, wrong type, wrong bot, expired) yields an
    unauthenticated principal with ``auth_source="missing"``. Expired tokens are
    revoked; errors raised while revoking propagate instead of letting the
    expired token through.
    """
    token = _normalize_token(raw_token)
    normalized_bot_id = str(bot_id or "").strip() or None
    unauthenticated = AuthPrincipal(expected_type, "", normalized_bot_id, False, "missing")
    if not token:
        return unauthenticated

    token_hash = _hash_session_token(token)
    payload = _active_token_payload(token_hash)
    if not payload:
        return unauthenticated

    auth_type, subject_id, payload_bot_id = _principal_from_payload(payload)
    if auth_type != expected_type:
        return unauthenticated
    if expected_type == "bot" and payload_bot_id != normalized_bot_id:
        return unauthenticated

    # The cache entry is written with ttl=ttl_seconds; this explicit check is a
    # second line of defence. An unparseable value is ignored, as before.
    expires_at = _parse_payload_datetime(payload.get("expires_at"))
    if expires_at is not None and expires_at <= _utcnow():
        _revoke_token_hash(session, token_hash, reason="expired")
        return unauthenticated

    row_id = int(payload.get("audit_id") or 0) or None
    if row_id is not None:
        row = session.get(AuthLoginLog, row_id)
        if row is not None:
            _touch_session(session, row)
    return AuthPrincipal(expected_type, subject_id, payload_bot_id, True, f"{expected_type}_token", row_id)


def resolve_panel_request_auth(session: Session, request: Request) -> AuthPrincipal:
    """Authenticate an HTTP request for the panel (open when no password is set)."""
    if not _is_panel_auth_enabled():
        return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
    return _resolve_token_auth(session, raw_token=_read_panel_token(request), expected_type="panel")


def resolve_bot_request_auth(session: Session, request: Request, bot_id: str) -> AuthPrincipal:
    """Authenticate an HTTP request for a specific bot."""
    normalized_bot_id = str(bot_id or "").strip()
    if not normalized_bot_id:
        return AuthPrincipal("bot", "", None, False, "missing")
    return _resolve_token_auth(
        session,
        raw_token=_read_bot_token(request, normalized_bot_id),
        expected_type="bot",
        bot_id=normalized_bot_id,
    )


def resolve_panel_websocket_auth(session: Session, websocket: WebSocket) -> AuthPrincipal:
    """Authenticate a websocket for the panel (open when no password is set)."""
    if not _is_panel_auth_enabled():
        return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
    return _resolve_token_auth(session, raw_token=_read_panel_token_ws(websocket), expected_type="panel")


def resolve_bot_websocket_auth(session: Session, websocket: WebSocket, bot_id: str) -> AuthPrincipal:
    """Authenticate a websocket for a specific bot."""
    normalized_bot_id = str(bot_id or "").strip()
    if not normalized_bot_id:
        return AuthPrincipal("bot", "", None, False, "missing")
    return _resolve_token_auth(
        session,
        raw_token=_read_bot_token_ws(websocket, normalized_bot_id),
        expected_type="bot",
        bot_id=normalized_bot_id,
    )