from __future__ import annotations

import hashlib
import re
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional

from fastapi import Request, Response, WebSocket
from sqlmodel import Session, select

from core.cache import auth_cache
from core.settings import PANEL_ACCESS_PASSWORD
from models.auth import AuthLoginLog
from models.bot import BotInstance
from services.platform_settings_service import get_auth_token_max_active, get_auth_token_ttl_hours

PANEL_TOKEN_COOKIE = "nanobot_panel_token"
BOT_TOKEN_COOKIE_PREFIX = "nanobot_bot_token_"
PANEL_SUBJECT_ID = "panel_admin"
# Extra seconds a per-principal token set lives beyond the TTL of the token just added to it.
AUTH_STORE_SET_TTL_BUFFER_SECONDS = 300
# Minimum seconds between last_seen_at writes in _touch_session (throttles DB commits).
SESSION_TOUCH_INTERVAL_SECONDS = 300

# Characters replaced by "_" when deriving a per-bot cookie name.
_UNSAFE_COOKIE_CHARS = re.compile(r"[^a-zA-Z0-9_-]+")
# Characters replaced by "_" when deriving cache key segments.
_UNSAFE_KEY_CHARS = re.compile(r"[^a-zA-Z0-9_.:-]+")


@dataclass(frozen=True)
class AuthPrincipal:
    """Result of resolving a request or websocket against panel/bot authentication."""

    auth_type: str  # "panel" or "bot" in this module
    subject_id: str  # empty string when not authenticated
    bot_id: Optional[str]
    authenticated: bool
    auth_source: str  # e.g. "missing", "unprotected", "panel_token", "bot_token"
    audit_id: Optional[int] = None  # AuthLoginLog.id backing the session, when known


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Naive values are kept because AuthLoginLog timestamps are compared against
    this value directly (presumably stored naive — the original code relied on
    that too). ``datetime.utcnow()`` is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _utc_epoch_seconds(value: datetime) -> int:
    """Convert a naive-UTC datetime (as produced by _utcnow) to integer epoch seconds.

    ``datetime.timestamp()`` interprets a naive value as *local* time, which
    would skew the result by the server's UTC offset, so UTC is attached first.
    """
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return int(value.timestamp())


def _normalize_token(raw: str) -> str:
    """Return ``raw`` as a stripped string ("" for None)."""
    return str(raw or "").strip()


def _hash_session_token(raw: str) -> str:
    """Return the SHA-256 hex digest of the normalized token (only hashes are stored)."""
    return hashlib.sha256(_normalize_token(raw).encode("utf-8")).hexdigest()


def _normalize_bot_cookie_name(bot_id: str) -> str:
    """Return the cookie name holding the session token for ``bot_id``."""
    safe_bot_id = _UNSAFE_COOKIE_CHARS.sub("_", str(bot_id or "").strip())
    return f"{BOT_TOKEN_COOKIE_PREFIX}{safe_bot_id or 'bot'}"


def _token_key(token_hash: str) -> str:
    """Cache key under which a token's JSON payload is stored."""
    return f"token:{str(token_hash or '').strip()}"


def _principal_tokens_key(auth_type: str, subject_id: str, bot_id: Optional[str] = None) -> str:
    """Cache key of the set holding every token hash issued to one principal."""
    normalized_type = str(auth_type or "").strip().lower() or "unknown"
    normalized_subject = _UNSAFE_KEY_CHARS.sub("_", str(subject_id or "").strip() or "anonymous")
    normalized_bot_id = _UNSAFE_KEY_CHARS.sub("_", str(bot_id or "").strip()) if bot_id else ""
    return f"principal:{normalized_type}:{normalized_subject}:{normalized_bot_id or '-'}"


def _auth_token_ttl_seconds(session: Session) -> int:
    """Configured token lifetime in seconds (at least one hour)."""
    return max(1, int(get_auth_token_ttl_hours(session))) * 60 * 60


def _auth_token_max_active(session: Session) -> int:
    """Configured maximum number of concurrently active tokens per principal (at least 1)."""
    return max(1, int(get_auth_token_max_active(session)))


def _touch_session(session: Session, row: AuthLoginLog) -> None:
    """Update ``row.last_seen_at``, at most once per SESSION_TOUCH_INTERVAL_SECONDS."""
    now = _utcnow()
    last_seen = row.last_seen_at or row.created_at or now
    if (now - last_seen).total_seconds() < SESSION_TOUCH_INTERVAL_SECONDS:
        return
    row.last_seen_at = now
    session.add(row)
    session.commit()


def _summarize_device(user_agent: str) -> str:
    """Return a coarse "<platform> / <browser>" label derived from a User-Agent string."""
    normalized = str(user_agent or "").strip().lower()
    if not normalized:
        return "Unknown Device"

    # Order matters: Edge UAs also contain "chrome/", Chrome UAs also contain "safari/".
    browser = "Unknown Browser"
    if "edg/" in normalized:
        browser = "Edge"
    elif "chrome/" in normalized:
        browser = "Chrome"
    elif "safari/" in normalized:
        browser = "Safari"
    elif "firefox/" in normalized:
        browser = "Firefox"

    # Order matters: iOS UAs mention "mac os x", Android UAs mention "linux".
    platform = "Desktop"
    if "iphone" in normalized:
        platform = "iPhone"
    elif "ipad" in normalized:
        platform = "iPad"
    elif "android" in normalized:
        platform = "Android"
    elif "mac os x" in normalized or "macintosh" in normalized:
        platform = "macOS"
    elif "windows" in normalized:
        platform = "Windows"
    elif "linux" in normalized:
        platform = "Linux"
    return f"{platform} / {browser}"


def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP (max 120 chars) for the audit log.

    NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    overwrites it; within this module the value only feeds the audit row.
    """
    forwarded = str(request.headers.get("x-forwarded-for") or "").strip()
    if forwarded:
        return forwarded.split(",")[0].strip()[:120]
    return str(getattr(request.client, "host", "") or "")[:120]


def _get_bearer_token(headers: Mapping[str, Any]) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header, or ""."""
    authorization = str(headers.get("authorization") or headers.get("Authorization") or "").strip()
    if not authorization.lower().startswith("bearer "):
        return ""
    return _normalize_token(authorization[len("bearer "):])


def _read_panel_token(request: Request) -> str:
    """Panel token from its cookie, falling back to the bearer header."""
    cookie_token = _normalize_token(request.cookies.get(PANEL_TOKEN_COOKIE) or "")
    return cookie_token or _get_bearer_token(request.headers)


def _read_bot_token(request: Request, bot_id: str) -> str:
    """Bot token from the per-bot cookie, falling back to the bearer header."""
    cookie_token = _normalize_token(request.cookies.get(_normalize_bot_cookie_name(bot_id)) or "")
    return cookie_token or _get_bearer_token(request.headers)


def _read_panel_token_ws(websocket: WebSocket) -> str:
    """Websocket variant of _read_panel_token."""
    cookie_token = _normalize_token(websocket.cookies.get(PANEL_TOKEN_COOKIE) or "")
    return cookie_token or _get_bearer_token(websocket.headers)


def _read_bot_token_ws(websocket: WebSocket, bot_id: str) -> str:
    """Websocket variant of _read_bot_token."""
    cookie_token = _normalize_token(websocket.cookies.get(_normalize_bot_cookie_name(bot_id)) or "")
    return cookie_token or _get_bearer_token(websocket.headers)


def _is_panel_auth_enabled() -> bool:
    """Panel auth is enforced only when PANEL_ACCESS_PASSWORD is non-blank."""
    return bool(str(PANEL_ACCESS_PASSWORD or "").strip())


def _get_bot_access_password(session: Session, bot_id: str) -> str:
    """Return the bot's stripped access password, or "" if unset or the bot is unknown."""
    bot = session.get(BotInstance, bot_id)
    if not bot:
        return ""
    return str(bot.access_password or "").strip()


def _is_bot_access_enabled(session: Session, bot_id: str) -> bool:
    """True when the bot has a non-blank access password."""
    return bool(_get_bot_access_password(session, bot_id))


def _resolve_bot_auth_source(session: Session, bot_id: str) -> str:
    """Audit label describing how a bot session was obtained."""
    return "bot_password" if _is_bot_access_enabled(session, bot_id) else "bot_public"


def _active_token_payload(token_hash: str) -> Optional[dict[str, Any]]:
    """Return the cached payload for ``token_hash`` if present and a dict."""
    payload = auth_cache.get_json(_token_key(token_hash))
    return payload if isinstance(payload, dict) else None


def _principal_from_payload(payload: dict[str, Any]) -> tuple[str, str, Optional[str]]:
    """Return the normalized ``(auth_type, subject_id, bot_id)`` recorded in a payload."""
    auth_type = str(payload.get("auth_type") or "").strip().lower()
    subject_id = str(payload.get("subject_id") or "").strip()
    bot_id = str(payload.get("bot_id") or "").strip() or None
    return auth_type, subject_id, bot_id


def _find_audit_row_by_token_hash(session: Session, token_hash: str) -> Optional[AuthLoginLog]:
    """Return the AuthLoginLog row for ``token_hash`` (or None)."""
    normalized_hash = str(token_hash or "").strip()
    if not normalized_hash:
        return None
    return session.exec(
        select(AuthLoginLog).where(AuthLoginLog.token_hash == normalized_hash).limit(1)
    ).first()


def _purge_cached_token(*, token_hash: str, auth_type: str, subject_id: str, bot_id: Optional[str]) -> None:
    """Drop a token's payload and its entry in the principal's token set (no-op without cache)."""
    if not auth_cache.enabled:
        return
    auth_cache.delete(_token_key(token_hash))
    auth_cache.srem(_principal_tokens_key(auth_type, subject_id, bot_id), token_hash)


def _active_token_row(
    session: Session,
    *,
    token_hash: str,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> Optional[AuthLoginLog]:
    """Return the audit row for ``token_hash`` if it is a live session of the expected kind.

    Returns None when the row is missing, of another type/bot, or revoked. An
    expired row is marked revoked ("expired"), committed and purged from the
    cache before returning None.
    """
    row = _find_audit_row_by_token_hash(session, token_hash)
    if row is None:
        return None
    normalized_bot_id = str(bot_id or "").strip() or None
    if row.auth_type != expected_type:
        return None
    if expected_type == "bot" and (str(row.bot_id or "").strip() or None) != normalized_bot_id:
        return None
    if row.revoked_at is not None:
        return None
    now = _utcnow()
    if row.expires_at <= now:
        row.last_seen_at = now
        row.revoked_at = now
        row.revoke_reason = "expired"
        session.add(row)
        session.commit()
        _purge_cached_token(
            token_hash=token_hash,
            auth_type=row.auth_type,
            subject_id=row.subject_id,
            bot_id=row.bot_id,
        )
        return None
    return row


def _list_active_token_rows(
    session: Session,
    *,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
) -> list[AuthLoginLog]:
    """Return a principal's unrevoked, unexpired audit rows, oldest first.

    Rows found to be expired are marked revoked ("expired") in one commit and
    purged from the cache as a side effect.
    """
    statement = select(AuthLoginLog).where(
        AuthLoginLog.auth_type == auth_type,
        AuthLoginLog.subject_id == subject_id,
        AuthLoginLog.revoked_at.is_(None),
    )
    normalized_bot_id = str(bot_id or "").strip() or None
    if normalized_bot_id is None:
        statement = statement.where(AuthLoginLog.bot_id.is_(None))
    else:
        statement = statement.where(AuthLoginLog.bot_id == normalized_bot_id)
    rows = list(session.exec(statement.order_by(AuthLoginLog.created_at.asc(), AuthLoginLog.id.asc())).all())

    now = _utcnow()
    expired_rows: list[AuthLoginLog] = []
    active_rows: list[AuthLoginLog] = []
    for row in rows:
        if row.expires_at > now:
            active_rows.append(row)
            continue
        row.last_seen_at = now
        row.revoked_at = now
        row.revoke_reason = "expired"
        session.add(row)
        expired_rows.append(row)

    if expired_rows:
        session.commit()
        for row in expired_rows:
            _purge_cached_token(
                token_hash=row.token_hash,
                auth_type=row.auth_type,
                subject_id=row.subject_id,
                bot_id=row.bot_id,
            )
    return active_rows


def _mark_audit_revoked(session: Session, token_hash: str, *, reason: str) -> None:
    """Stamp the audit row for ``token_hash`` as revoked and commit.

    ``last_seen_at`` is always refreshed; ``revoked_at`` and ``revoke_reason``
    are only written the first time, so the original revocation reason is kept.
    """
    row = _find_audit_row_by_token_hash(session, token_hash)
    if not row:
        return
    now = _utcnow()
    row.last_seen_at = now
    if row.revoked_at is None:
        row.revoked_at = now
        row.revoke_reason = str(reason or "").strip()[:120] or row.revoke_reason
    session.add(row)
    session.commit()


def _revoke_token_hash(session: Session, token_hash: str, *, reason: str) -> None:
    """Revoke a token by hash: purge its cache entries and mark its audit row revoked."""
    normalized_hash = str(token_hash or "").strip()
    if not normalized_hash:
        return
    # Guard like _resolve_token_auth does: only consult the cache when it is enabled.
    payload = _active_token_payload(normalized_hash) if auth_cache.enabled else None
    if payload:
        auth_type, subject_id, bot_id = _principal_from_payload(payload)
        _purge_cached_token(
            token_hash=normalized_hash,
            auth_type=auth_type,
            subject_id=subject_id,
            bot_id=bot_id,
        )
    _mark_audit_revoked(session, normalized_hash, reason=reason)


def _token_hash_in_scope(
    session: Session,
    token_hash: str,
    *,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> bool:
    """Return False when ``token_hash`` is known to belong to another auth type or bot.

    Used so that revoking "the panel token" or "bot X's token" never revokes an
    unrelated session that merely arrived via the shared Authorization header.
    """
    row = _find_audit_row_by_token_hash(session, token_hash)
    if row is not None:
        owner_type = row.auth_type
        owner_bot_id = str(row.bot_id or "").strip() or None
    else:
        payload = _active_token_payload(token_hash) if auth_cache.enabled else None
        if not payload:
            # Unknown token: revoking it is a no-op either way.
            return True
        owner_type, _, owner_bot_id = _principal_from_payload(payload)
    if owner_type != expected_type:
        return False
    if expected_type == "bot" and owner_bot_id != (str(bot_id or "").strip() or None):
        return False
    return True


def _revoke_raw_token(
    session: Session,
    raw_token: str,
    *,
    reason: str,
    expected_type: Optional[str] = None,
    bot_id: Optional[str] = None,
) -> None:
    """Revoke a raw (unhashed) token.

    When ``expected_type`` is given, the token is only revoked if it belongs to
    that auth type (and, for "bot", to ``bot_id``).
    """
    token = _normalize_token(raw_token)
    if not token:
        return
    token_hash = _hash_session_token(token)
    if expected_type is not None and not _token_hash_in_scope(
        session, token_hash, expected_type=expected_type, bot_id=bot_id
    ):
        return
    _revoke_token_hash(session, token_hash, reason=reason)


def _cleanup_principal_set(session: Session, principal_key: str) -> list[tuple[int, str]]:
    """Prune a principal's token set of hashes whose payload has vanished.

    Pruned hashes have their audit rows marked revoked ("expired"). Returns the
    remaining ``(issued_at_ts, token_hash)`` pairs sorted oldest first.
    """
    active_rows: list[tuple[int, str]] = []
    stale_hashes: list[str] = []
    for token_hash in auth_cache.smembers(principal_key):
        payload = _active_token_payload(token_hash)
        if not payload:
            stale_hashes.append(token_hash)
            continue
        issued_at_ts = int(payload.get("issued_at_ts") or 0)
        active_rows.append((issued_at_ts, token_hash))
    if stale_hashes:
        auth_cache.srem(principal_key, *stale_hashes)
        for stale_hash in stale_hashes:
            _mark_audit_revoked(session, stale_hash, reason="expired")
    return sorted(active_rows)


def _ensure_auth_store_available() -> None:
    """Hook for verifying the auth store before issuing tokens; currently a no-op."""
    return


def _persist_token_payload(
    session: Session,
    *,
    row: AuthLoginLog,
    raw_token: str,
    ttl_seconds: int,
) -> None:
    """Mirror a freshly created session into auth_cache and verify the write.

    No-op when the cache is disabled (the DB row alone then backs the session).

    Raises:
        RuntimeError: the payload could not be read back; the audit row is
            first marked revoked with reason "store_write_failed".
    """
    if not auth_cache.enabled:
        return
    token_hash = _hash_session_token(raw_token)
    payload = {
        "auth_type": row.auth_type,
        "subject_id": row.subject_id,
        "bot_id": row.bot_id,
        "auth_source": row.auth_source,
        "issued_at": row.created_at.isoformat() + "Z",
        "issued_at_ts": _utc_epoch_seconds(row.created_at),
        "expires_at": row.expires_at.isoformat() + "Z",
        "audit_id": int(row.id or 0),
    }
    principal_key = _principal_tokens_key(row.auth_type, row.subject_id, row.bot_id)
    auth_cache.set_json(_token_key(token_hash), payload, ttl=ttl_seconds)
    auth_cache.sadd(principal_key, token_hash)
    auth_cache.expire(principal_key, ttl_seconds + AUTH_STORE_SET_TTL_BUFFER_SECONDS)
    if not _active_token_payload(token_hash):
        row.revoked_at = _utcnow()
        row.revoke_reason = "store_write_failed"
        session.add(row)
        session.commit()
        raise RuntimeError("Failed to persist authentication token")


def _enforce_token_limit(
    session: Session,
    *,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    max_active: int,
) -> None:
    """Revoke the oldest active tokens so one more can be issued within ``max_active``."""
    rows = [
        (_utc_epoch_seconds(row.created_at), row.token_hash)
        for row in _list_active_token_rows(
            session,
            auth_type=auth_type,
            subject_id=subject_id,
            bot_id=bot_id,
        )
    ]
    # "+ 1" reserves a slot for the token about to be created.
    overflow = max(0, len(rows) - max_active + 1)
    for _, token_hash in rows[:overflow]:
        _revoke_token_hash(session, token_hash, reason="concurrency_limit")


def _create_audit_row(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    raw_token: str,
    expires_at: datetime,
    auth_source: str,
) -> AuthLoginLog:
    """Insert, commit and return the AuthLoginLog row for a new session."""
    now = _utcnow()
    user_agent = str(request.headers.get("user-agent") or "")
    row = AuthLoginLog(
        auth_type=auth_type,
        token_hash=_hash_session_token(raw_token),
        subject_id=subject_id,
        bot_id=bot_id,
        auth_source=auth_source,
        created_at=now,
        expires_at=expires_at,
        last_seen_at=now,
        client_ip=_extract_client_ip(request),
        user_agent=user_agent[:500],
        device_info=_summarize_device(user_agent)[:255],
    )
    session.add(row)
    session.commit()
    session.refresh(row)
    return row


def _create_auth_token(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    auth_source: str,
) -> str:
    """Issue a new raw session token: enforce the concurrency limit, audit it, cache it."""
    _ensure_auth_store_available()
    ttl_seconds = _auth_token_ttl_seconds(session)
    max_active = _auth_token_max_active(session)
    _enforce_token_limit(
        session,
        auth_type=auth_type,
        subject_id=subject_id,
        bot_id=bot_id,
        max_active=max_active,
    )
    raw_token = secrets.token_urlsafe(32)
    row = _create_audit_row(
        session,
        request=request,
        auth_type=auth_type,
        subject_id=subject_id,
        bot_id=bot_id,
        raw_token=raw_token,
        expires_at=_utcnow() + timedelta(seconds=ttl_seconds),
        auth_source=auth_source,
    )
    _persist_token_payload(session, row=row, raw_token=raw_token, ttl_seconds=ttl_seconds)
    return raw_token


def create_panel_token(session: Session, request: Request) -> str:
    """Issue a panel token, superseding the panel token the request already carries."""
    revoke_panel_token(session, request, reason="superseded")
    return _create_auth_token(
        session,
        request=request,
        auth_type="panel",
        subject_id=PANEL_SUBJECT_ID,
        bot_id=None,
        auth_source="panel_password",
    )


def create_bot_token(session: Session, request: Request, bot_id: str) -> str:
    """Issue a token for ``bot_id``, superseding that bot's token the request already carries."""
    normalized_bot_id = str(bot_id or "").strip()
    revoke_bot_token(session, request, normalized_bot_id, reason="superseded")
    return _create_auth_token(
        session,
        request=request,
        auth_type="bot",
        subject_id=normalized_bot_id,
        bot_id=normalized_bot_id,
        auth_source=_resolve_bot_auth_source(session, normalized_bot_id),
    )


def revoke_panel_token(session: Session, request: Request, reason: str = "logout") -> None:
    """Revoke the request's panel token; tokens of other auth types are left untouched."""
    _revoke_raw_token(session, _read_panel_token(request), reason=reason, expected_type="panel")


def revoke_bot_token(session: Session, request: Request, bot_id: str, reason: str = "logout") -> None:
    """Revoke the request's token for ``bot_id``; panel or other-bot tokens are left untouched."""
    _revoke_raw_token(
        session,
        _read_bot_token(request, bot_id),
        reason=reason,
        expected_type="bot",
        bot_id=bot_id,
    )


def _set_cookie(response: Response, request: Request, name: str, raw_token: str, max_age: int) -> None:
    """Set an HttpOnly, SameSite=Lax, path=/ session cookie.

    NOTE(review): ``secure`` follows the request scheme; behind a TLS-terminating
    proxy that depends on the server honouring forwarded headers — confirm.
    """
    response.set_cookie(
        name,
        raw_token,
        max_age=max_age,
        httponly=True,
        samesite="lax",
        secure=str(request.url.scheme).lower() == "https",
        path="/",
    )


def set_panel_token_cookie(response: Response, request: Request, raw_token: str, session: Session) -> None:
    """Store the panel token in its cookie with the configured TTL."""
    _set_cookie(response, request, PANEL_TOKEN_COOKIE, raw_token, _auth_token_ttl_seconds(session))


def set_bot_token_cookie(response: Response, request: Request, bot_id: str, raw_token: str, session: Session) -> None:
    """Store a bot token in that bot's cookie with the configured TTL."""
    _set_cookie(response, request, _normalize_bot_cookie_name(bot_id), raw_token, _auth_token_ttl_seconds(session))


def clear_panel_token_cookie(response: Response) -> None:
    """Delete the panel token cookie."""
    response.delete_cookie(PANEL_TOKEN_COOKIE, path="/")


def clear_bot_token_cookie(response: Response, bot_id: str) -> None:
    """Delete the token cookie for ``bot_id``."""
    response.delete_cookie(_normalize_bot_cookie_name(bot_id), path="/")


def _unauthenticated(expected_type: str, bot_id: Optional[str]) -> AuthPrincipal:
    """The principal returned whenever no valid session is found."""
    return AuthPrincipal(expected_type, "", bot_id, False, "missing")


def _resolve_from_db(
    session: Session,
    *,
    token_hash: str,
    expected_type: str,
    bot_id: Optional[str],
) -> AuthPrincipal:
    """Resolve a token purely from its audit row (used when the cache can't answer)."""
    row = _active_token_row(session, token_hash=token_hash, expected_type=expected_type, bot_id=bot_id)
    if row is None:
        return _unauthenticated(expected_type, bot_id)
    _touch_session(session, row)
    return AuthPrincipal(expected_type, row.subject_id, row.bot_id, True, f"{expected_type}_token", row.id)


def _parse_payload_expiry(payload: dict[str, Any]) -> Optional[datetime]:
    """Parse a payload's ``expires_at`` into a naive-UTC datetime, or None if absent/invalid."""
    expires_at_raw = str(payload.get("expires_at") or "").strip()
    if not expires_at_raw:
        return None
    try:
        expires_at = datetime.fromisoformat(expires_at_raw.replace("Z", ""))
    except ValueError:
        return None
    if expires_at.tzinfo is not None:
        # Comparing aware with naive raises TypeError; normalize to naive UTC.
        expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)
    return expires_at


def _resolve_token_auth(
    session: Session,
    *,
    raw_token: str,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> AuthPrincipal:
    """Resolve a raw token to an AuthPrincipal of ``expected_type`` (and ``bot_id`` for bots).

    The cached payload is consulted first when the cache is enabled; the audit
    row is always checked for revocation/expiry. Expired tokens are revoked.
    """
    token = _normalize_token(raw_token)
    normalized_bot_id = str(bot_id or "").strip() or None
    if not token:
        return _unauthenticated(expected_type, normalized_bot_id)

    token_hash = _hash_session_token(token)
    payload = _active_token_payload(token_hash) if auth_cache.enabled else None
    if not payload:
        return _resolve_from_db(session, token_hash=token_hash, expected_type=expected_type, bot_id=normalized_bot_id)

    auth_type, subject_id, payload_bot_id = _principal_from_payload(payload)
    if auth_type != expected_type or (expected_type == "bot" and payload_bot_id != normalized_bot_id):
        return _resolve_from_db(session, token_hash=token_hash, expected_type=expected_type, bot_id=normalized_bot_id)

    # Only the parse is guarded; a failing revocation must not be swallowed.
    expires_at = _parse_payload_expiry(payload)
    if expires_at is not None and expires_at <= _utcnow():
        _revoke_token_hash(session, token_hash, reason="expired")
        return _unauthenticated(expected_type, normalized_bot_id)

    row_id = int(payload.get("audit_id") or 0) or None
    if row_id is not None:
        row = session.get(AuthLoginLog, row_id)
        if row is None or row.revoked_at is not None:
            return _resolve_from_db(
                session,
                token_hash=token_hash,
                expected_type=expected_type,
                bot_id=normalized_bot_id,
            )
        if row.expires_at <= _utcnow():
            _revoke_token_hash(session, token_hash, reason="expired")
            return _unauthenticated(expected_type, normalized_bot_id)
        _touch_session(session, row)
    return AuthPrincipal(expected_type, subject_id, payload_bot_id, True, f"{expected_type}_token", row_id)


def resolve_panel_request_auth(session: Session, request: Request) -> AuthPrincipal:
    """Authenticate an HTTP request for the panel (always allowed when panel auth is off)."""
    if not _is_panel_auth_enabled():
        return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
    return _resolve_token_auth(session, raw_token=_read_panel_token(request), expected_type="panel")


def resolve_bot_request_auth(session: Session, request: Request, bot_id: str) -> AuthPrincipal:
    """Authenticate an HTTP request against ``bot_id``'s session token."""
    normalized_bot_id = str(bot_id or "").strip()
    if not normalized_bot_id:
        return AuthPrincipal("bot", "", None, False, "missing")
    return _resolve_token_auth(
        session,
        raw_token=_read_bot_token(request, normalized_bot_id),
        expected_type="bot",
        bot_id=normalized_bot_id,
    )


def resolve_panel_websocket_auth(session: Session, websocket: WebSocket) -> AuthPrincipal:
    """Websocket variant of resolve_panel_request_auth."""
    if not _is_panel_auth_enabled():
        return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
    return _resolve_token_auth(session, raw_token=_read_panel_token_ws(websocket), expected_type="panel")


def resolve_bot_websocket_auth(session: Session, websocket: WebSocket, bot_id: str) -> AuthPrincipal:
    """Websocket variant of resolve_bot_request_auth."""
    normalized_bot_id = str(bot_id or "").strip()
    if not normalized_bot_id:
        return AuthPrincipal("bot", "", None, False, "missing")
    return _resolve_token_auth(
        session,
        raw_token=_read_bot_token_ws(websocket, normalized_bot_id),
        expected_type="bot",
        bot_id=normalized_bot_id,
    )