dashboard-nanobot/backend/services/platform_auth_service.py

605 lines
21 KiB
Python

from __future__ import annotations

import hashlib
import re
import secrets
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional

from fastapi import Request, Response, WebSocket
from sqlmodel import Session, select

from core.cache import auth_cache
from core.settings import PANEL_ACCESS_PASSWORD
from models.auth import AuthLoginLog
from models.bot import BotInstance
from services.platform_settings_service import get_auth_token_max_active, get_auth_token_ttl_hours
# Name of the cookie that carries the panel (dashboard admin) session token.
PANEL_TOKEN_COOKIE: str = "nanobot_panel_token"
# Per-bot session cookies are named with this prefix followed by a sanitized bot id.
BOT_TOKEN_COOKIE_PREFIX: str = "nanobot_bot_token_"
# Fixed subject id recorded for every panel session.
PANEL_SUBJECT_ID: str = "panel_admin"
# Extra seconds of expiry given to a principal's cached token set, on top of the token TTL.
AUTH_STORE_SET_TTL_BUFFER_SECONDS: int = 300
# Minimum number of seconds between two writes of a session's ``last_seen_at``.
SESSION_TOUCH_INTERVAL_SECONDS: int = 300
@dataclass(frozen=True)
class AuthPrincipal:
    """Immutable outcome of checking a panel or bot credential.

    Instances are constructed positionally throughout this module, so the field
    order below is part of the interface.
    """

    # Kind of credential that was checked: "panel" or "bot".
    auth_type: str
    # Identity the token was issued to; "" when authentication failed.
    subject_id: str
    # Bot the principal is scoped to; None for panel principals.
    bot_id: Optional[str]
    # True only when a valid credential was found (or panel auth is disabled).
    authenticated: bool
    # How the result was reached, e.g. "panel_token", "bot_token", "unprotected", "missing".
    auth_source: str
    # Primary key of the matching AuthLoginLog row, when one is known.
    audit_id: Optional[int] = None
def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Timestamps written by this module (audit ``created_at``/``expires_at``) are
    naive UTC values and are compared against this clock, so the tzinfo is
    stripped after reading the aware clock. The result equals what
    ``datetime.utcnow()`` returned, but that function is deprecated since
    Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)
def _normalize_token(raw: str) -> str:
    """Coerce ``raw`` to a string with surrounding whitespace removed ("" for falsy input)."""
    text = raw if raw else ""
    return str(text).strip()
def _hash_session_token(raw: str) -> str:
    """Return the hex SHA-256 digest of the normalized token.

    Audit rows and cache keys in this module are keyed by this digest rather
    than by the raw token.
    """
    digest = hashlib.sha256()
    digest.update(_normalize_token(raw).encode("utf-8"))
    return digest.hexdigest()
def _normalize_bot_cookie_name(bot_id: str) -> str:
    """Build the per-bot cookie name from ``BOT_TOKEN_COOKIE_PREFIX`` and a sanitized id.

    Each run of characters outside ``[a-zA-Z0-9_-]`` collapses to a single
    ``_``; a blank id falls back to ``"bot"``.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_-]+", "_", str(bot_id or "").strip())
    suffix = cleaned if cleaned else "bot"
    return BOT_TOKEN_COOKIE_PREFIX + suffix
def _token_key(token_hash: str) -> str:
    """Cache key under which a token's JSON payload is stored."""
    cleaned = str(token_hash or "").strip()
    return "token:" + cleaned
def _principal_tokens_key(auth_type: str, subject_id: str, bot_id: Optional[str] = None) -> str:
    """Cache key of the set that holds every active token digest for one principal.

    Shape: ``principal:<type>:<subject>:<bot>``. Runs of characters outside
    ``[a-zA-Z0-9_.:-]`` become ``_``; blank parts get the placeholders
    "unknown", "anonymous" and "-" respectively. The type is lower-cased.
    """
    unsafe_chars = r"[^a-zA-Z0-9_.:-]+"
    kind = str(auth_type or "").strip().lower()
    if not kind:
        kind = "unknown"
    subject = re.sub(unsafe_chars, "_", str(subject_id or "").strip() or "anonymous")
    bot_part = "-"
    if bot_id:
        bot_part = re.sub(unsafe_chars, "_", str(bot_id or "").strip()) or "-"
    return f"principal:{kind}:{subject}:{bot_part}"
def _auth_token_ttl_seconds(session: Session) -> int:
    """Configured token lifetime converted to seconds, never less than one hour."""
    hours = int(get_auth_token_ttl_hours(session))
    return max(hours, 1) * 3600
def _auth_token_max_active(session: Session) -> int:
    """Configured cap on concurrently active tokens per principal, never below 1."""
    configured = int(get_auth_token_max_active(session))
    return configured if configured > 1 else 1
def _touch_session(session: Session, row: AuthLoginLog) -> None:
    """Bump ``row.last_seen_at`` to now, at most once per ``SESSION_TOUCH_INTERVAL_SECONDS``.

    Elapsed time is measured from ``last_seen_at``, falling back to
    ``created_at`` when that is unset. The row is only added and committed when
    the interval has passed, so per-request authentication does not commit on
    every call.
    """
    now = _utcnow()
    reference = row.last_seen_at or row.created_at or now
    elapsed = (now - reference).total_seconds()
    if elapsed >= SESSION_TOUCH_INTERVAL_SECONDS:
        row.last_seen_at = now
        session.add(row)
        session.commit()
def _summarize_device(user_agent: str) -> str:
    """Derive a coarse ``"<platform> / <browser>"`` label from a User-Agent header.

    Matching is plain substring search on the lower-cased header, and the first
    matching marker in each table wins. Order therefore matters: Edge UAs
    typically also contain "chrome/", Chrome UAs "safari/", and Android UAs
    "linux". A blank header yields "Unknown Device".
    """
    ua = str(user_agent or "").strip().lower()
    if not ua:
        return "Unknown Device"
    browser_markers = (
        ("edg/", "Edge"),
        ("chrome/", "Chrome"),
        ("safari/", "Safari"),
        ("firefox/", "Firefox"),
    )
    platform_markers = (
        (("iphone",), "iPhone"),
        (("ipad",), "iPad"),
        (("android",), "Android"),
        (("mac os x", "macintosh"), "macOS"),
        (("windows",), "Windows"),
        (("linux",), "Linux"),
    )
    browser = next((label for marker, label in browser_markers if marker in ua), "Unknown Browser")
    platform = next(
        (label for markers, label in platform_markers if any(marker in ua for marker in markers)),
        "Desktop",
    )
    return f"{platform} / {browser}"
def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP for the audit log, truncated to 120 characters.

    Uses the first entry of ``X-Forwarded-For`` when that header is present,
    otherwise the connection's peer host ("" when unavailable).

    NOTE(review): ``X-Forwarded-For`` is taken verbatim from the request, so a
    client that does not sit behind a trusted proxy can spoof this value —
    treat it as informational only.
    """
    forwarded_for = str(request.headers.get("x-forwarded-for") or "").strip()
    if not forwarded_for:
        peer_host = getattr(request.client, "host", "") or ""
        return str(peer_host)[:120]
    first_hop, _, _ = forwarded_for.partition(",")
    return first_hop.strip()[:120]
def _get_bearer_token(headers: Mapping[str, Any]) -> str:
    """Extract the credential from an ``Authorization: Bearer <token>`` header.

    The "bearer " prefix is matched case-insensitively; any other header value
    (or no header) yields "".
    """
    header_value = headers.get("authorization") or headers.get("Authorization") or ""
    header_value = str(header_value).strip()
    bearer_prefix = "bearer "
    if header_value.lower().startswith(bearer_prefix):
        return _normalize_token(header_value[len(bearer_prefix):])
    return ""
def _read_panel_token(request: Request) -> str:
    """Panel token of an HTTP request: the panel cookie first, then a Bearer header."""
    from_cookie = _normalize_token(request.cookies.get(PANEL_TOKEN_COOKIE) or "")
    if from_cookie:
        return from_cookie
    return _get_bearer_token(request.headers)
def _read_bot_token(request: Request, bot_id: str) -> str:
    """Bot token of an HTTP request: the per-bot cookie first, then a Bearer header."""
    cookie_name = _normalize_bot_cookie_name(bot_id)
    from_cookie = _normalize_token(request.cookies.get(cookie_name) or "")
    return from_cookie if from_cookie else _get_bearer_token(request.headers)
def _read_panel_token_ws(websocket: WebSocket) -> str:
    """Panel token of a WebSocket handshake: the panel cookie first, then a Bearer header."""
    from_cookie = _normalize_token(websocket.cookies.get(PANEL_TOKEN_COOKIE) or "")
    if not from_cookie:
        return _get_bearer_token(websocket.headers)
    return from_cookie
def _read_bot_token_ws(websocket: WebSocket, bot_id: str) -> str:
    """Bot token of a WebSocket handshake: the per-bot cookie first, then a Bearer header."""
    cookie_name = _normalize_bot_cookie_name(bot_id)
    from_cookie = _normalize_token(websocket.cookies.get(cookie_name) or "")
    if not from_cookie:
        return _get_bearer_token(websocket.headers)
    return from_cookie
def _is_panel_auth_enabled() -> bool:
    """Panel authentication is active only when a non-blank panel password is configured."""
    configured_password = str(PANEL_ACCESS_PASSWORD or "").strip()
    return configured_password != ""
def _get_bot_access_password(session: Session, bot_id: str) -> str:
    """Trimmed access password of the bot, or "" when the bot is unknown or has none."""
    bot = session.get(BotInstance, bot_id)
    return str(bot.access_password or "").strip() if bot else ""
def _is_bot_access_enabled(session: Session, bot_id: str) -> bool:
    """True when the bot exists and has a non-blank access password."""
    return _get_bot_access_password(session, bot_id) != ""
def _resolve_bot_auth_source(session: Session, bot_id: str) -> str:
    """``auth_source`` label for bot tokens: password-protected vs. public bot."""
    if _is_bot_access_enabled(session, bot_id):
        return "bot_password"
    return "bot_public"
def _active_token_payload(token_hash: str) -> Optional[dict[str, Any]]:
    """Cached payload for ``token_hash``; any non-dict cache value is treated as absent."""
    cached = auth_cache.get_json(_token_key(token_hash))
    if isinstance(cached, dict):
        return cached
    return None
def _principal_from_payload(payload: dict[str, Any]) -> tuple[str, str, Optional[str]]:
    """Unpack ``(auth_type, subject_id, bot_id)`` from a cached token payload.

    Every value is stripped; ``auth_type`` is also lower-cased, and a blank
    ``bot_id`` becomes ``None``.
    """

    def clean(field: str) -> str:
        return str(payload.get(field) or "").strip()

    bot_id = clean("bot_id")
    return clean("auth_type").lower(), clean("subject_id"), bot_id or None
def _find_audit_row_by_token_hash(session: Session, token_hash: str) -> Optional[AuthLoginLog]:
    """Fetch the audit row for a token digest; a blank digest returns None without querying."""
    wanted = str(token_hash or "").strip()
    if wanted:
        query = select(AuthLoginLog).where(AuthLoginLog.token_hash == wanted).limit(1)
        return session.exec(query).first()
    return None
def _purge_cached_token(*, token_hash: str, auth_type: str, subject_id: str, bot_id: Optional[str]) -> None:
    """Remove a token from the auth cache: its payload key and its principal-set entry.

    Does nothing when the cache is disabled.
    """
    if auth_cache.enabled:
        auth_cache.delete(_token_key(token_hash))
        auth_cache.srem(_principal_tokens_key(auth_type, subject_id, bot_id), token_hash)
def _active_token_row(
    session: Session,
    *,
    token_hash: str,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> Optional[AuthLoginLog]:
    """Return the audit row for ``token_hash`` when it is a live token of the expected kind.

    Returns None when the row is missing, has another ``auth_type``, belongs to
    a different bot (checked for bot tokens only) or is already revoked. A row
    whose ``expires_at`` has passed is revoked with reason "expired",
    committed and purged from the cache, and None is returned.
    """
    row = _find_audit_row_by_token_hash(session, token_hash)
    if row is None or row.auth_type != expected_type:
        return None
    if expected_type == "bot":
        wanted_bot = str(bot_id or "").strip() or None
        row_bot = str(row.bot_id or "").strip() or None
        if row_bot != wanted_bot:
            return None
    if row.revoked_at is not None:
        return None
    if row.expires_at > _utcnow():
        return row
    # Lazily retire the expired row so later lookups can skip it.
    expired_at = _utcnow()
    row.last_seen_at = expired_at
    row.revoked_at = expired_at
    row.revoke_reason = "expired"
    session.add(row)
    session.commit()
    _purge_cached_token(
        token_hash=token_hash,
        auth_type=row.auth_type,
        subject_id=row.subject_id,
        bot_id=row.bot_id,
    )
    return None
def _list_active_token_rows(
    session: Session,
    *,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
) -> list[AuthLoginLog]:
    """List a principal's unrevoked, unexpired audit rows, oldest first.

    Rows that are still unrevoked but past ``expires_at`` are revoked as
    "expired" in a single commit and purged from the cache. They are not
    returned.
    """
    wanted_bot = str(bot_id or "").strip() or None
    bot_filter = (AuthLoginLog.bot_id.is_(None)) if wanted_bot is None else (AuthLoginLog.bot_id == wanted_bot)
    query = (
        select(AuthLoginLog)
        .where(
            AuthLoginLog.auth_type == auth_type,
            AuthLoginLog.subject_id == subject_id,
            AuthLoginLog.revoked_at.is_(None),
        )
        .where(bot_filter)
        .order_by(AuthLoginLog.created_at.asc(), AuthLoginLog.id.asc())
    )
    candidates = list(session.exec(query).all())
    now = _utcnow()
    live = [row for row in candidates if row.expires_at > now]
    stale = [row for row in candidates if row.expires_at <= now]
    for row in stale:
        row.last_seen_at = now
        row.revoked_at = now
        row.revoke_reason = "expired"
        session.add(row)
    if stale:
        session.commit()
        for row in stale:
            _purge_cached_token(
                token_hash=row.token_hash,
                auth_type=row.auth_type,
                subject_id=row.subject_id,
                bot_id=row.bot_id,
            )
    return live
def _mark_audit_revoked(session: Session, token_hash: str, *, reason: str) -> None:
    """Mark the audit row for ``token_hash`` as revoked, recording ``reason``.

    Revocation is first-write-wins. ``revoked_at``, ``revoke_reason`` and
    ``last_seen_at`` are only written for a row that is not revoked yet. This
    stops a later sweep from rewriting the reason stored by the original
    revocation while leaving the old ``revoked_at`` in place. Examples of such
    sweeps: a stale cache entry treated as "expired", or a "superseded" login
    that still presents an old cookie. Unknown or already-revoked rows are left
    untouched, and nothing is committed for them.

    ``reason`` is stripped and truncated to 120 characters. A blank reason
    keeps the row's existing ``revoke_reason``.
    """
    row = _find_audit_row_by_token_hash(session, token_hash)
    if row is None or row.revoked_at is not None:
        return
    now = _utcnow()
    row.last_seen_at = now
    row.revoked_at = now
    row.revoke_reason = str(reason or "").strip()[:120] or row.revoke_reason
    session.add(row)
    session.commit()
def _revoke_token_hash(session: Session, token_hash: str, *, reason: str) -> None:
    """Revoke a token by digest: evict it from the cache (when cached), then mark its audit row.

    A blank digest is ignored.
    """
    digest = str(token_hash or "").strip()
    if not digest:
        return
    cached = _active_token_payload(digest)
    if cached:
        kind, subject, bot = _principal_from_payload(cached)
        auth_cache.delete(_token_key(digest))
        auth_cache.srem(_principal_tokens_key(kind, subject, bot), digest)
    _mark_audit_revoked(session, digest, reason=reason)
def _revoke_raw_token(session: Session, raw_token: str, *, reason: str) -> None:
    """Revoke a token presented in raw (cookie/header) form; blank tokens are ignored."""
    cleaned = _normalize_token(raw_token)
    if cleaned:
        _revoke_token_hash(session, _hash_session_token(cleaned), reason=reason)
def _cleanup_principal_set(session: Session, principal_key: str) -> list[tuple[int, str]]:
    """Prune a principal's cached token set and report the surviving members.

    Members whose payload is no longer cached are removed from the set, and
    their audit rows are marked revoked with reason "expired". Returns
    ``(issued_at_ts, token_hash)`` pairs for the remaining members, sorted by
    ``issued_at_ts`` then hash.
    """
    survivors: list[tuple[int, str]] = []
    missing: list[str] = []
    for member in auth_cache.smembers(principal_key):
        cached = _active_token_payload(member)
        if cached:
            survivors.append((int(cached.get("issued_at_ts") or 0), member))
        else:
            missing.append(member)
    if missing:
        auth_cache.srem(principal_key, *missing)
        for member in missing:
            _mark_audit_revoked(session, member, reason="expired")
    survivors.sort()
    return survivors
def _ensure_auth_store_available() -> None:
    """Pre-issuance hook for checking the auth store; currently a no-op.

    Token issuance does not require the cache: when it is disabled, payloads
    are simply not mirrored, and validation falls back to the audit table.
    """
    return None
def _persist_token_payload(
    session: Session,
    *,
    row: AuthLoginLog,
    raw_token: str,
    ttl_seconds: int,
) -> None:
    """Mirror a freshly issued token into the auth cache.

    Stores the payload under the token's digest key with a ``ttl_seconds``
    expiry. It also adds the digest to the principal's token set, whose expiry
    gets ``AUTH_STORE_SET_TTL_BUFFER_SECONDS`` of extra slack. Does nothing
    when the cache is disabled.

    Raises:
        RuntimeError: if the payload cannot be read back after the write. The
            audit row is first marked revoked with reason "store_write_failed".
    """
    if not auth_cache.enabled:
        return
    token_hash = _hash_session_token(raw_token)
    created_at = row.created_at
    # ``created_at`` is naive UTC (set from ``_utcnow()`` in ``_create_audit_row``).
    # ``.timestamp()`` on a naive datetime assumes the server's *local* zone, which
    # would skew ``issued_at_ts`` by the UTC offset, so pin it to UTC first.
    if created_at.tzinfo is None:
        created_at = created_at.replace(tzinfo=timezone.utc)
    payload = {
        "auth_type": row.auth_type,
        "subject_id": row.subject_id,
        "bot_id": row.bot_id,
        "auth_source": row.auth_source,
        "issued_at": row.created_at.isoformat() + "Z",
        "issued_at_ts": int(created_at.timestamp()),
        "expires_at": row.expires_at.isoformat() + "Z",
        "audit_id": int(row.id or 0),
    }
    principal_key = _principal_tokens_key(row.auth_type, row.subject_id, row.bot_id)
    auth_cache.set_json(_token_key(token_hash), payload, ttl=ttl_seconds)
    auth_cache.sadd(principal_key, token_hash)
    auth_cache.expire(principal_key, ttl_seconds + AUTH_STORE_SET_TTL_BUFFER_SECONDS)
    # Read back to detect a silently failed cache write before handing out the token.
    if not _active_token_payload(token_hash):
        row.revoked_at = _utcnow()
        row.revoke_reason = "store_write_failed"
        session.add(row)
        session.commit()
        raise RuntimeError("Failed to persist authentication token")
def _enforce_token_limit(
    session: Session,
    *,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    max_active: int,
) -> None:
    """Free one slot under the ``max_active`` cap before a new token is issued.

    Revokes the principal's oldest active tokens (reason "concurrency_limit")
    so that at most ``max_active - 1`` remain, leaving room for the one about
    to be created.
    """
    active_hashes = [
        row.token_hash
        for row in _list_active_token_rows(session, auth_type=auth_type, subject_id=subject_id, bot_id=bot_id)
    ]
    excess = len(active_hashes) - (max_active - 1)
    for token_hash in active_hashes[: max(excess, 0)]:
        _revoke_token_hash(session, token_hash, reason="concurrency_limit")
def _create_audit_row(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    raw_token: str,
    expires_at: datetime,
    auth_source: str,
) -> AuthLoginLog:
    """Insert, commit and return (refreshed) the audit row for a newly issued token.

    The row stores the token's digest, not the raw token, together with the
    client IP, the User-Agent (first 500 chars) and a derived device label
    (first 255 chars).
    """
    issued_at = _utcnow()
    user_agent = str(request.headers.get("user-agent") or "")
    row = AuthLoginLog(
        auth_type=auth_type,
        token_hash=_hash_session_token(raw_token),
        subject_id=subject_id,
        bot_id=bot_id,
        auth_source=auth_source,
        created_at=issued_at,
        expires_at=expires_at,
        last_seen_at=issued_at,
        client_ip=_extract_client_ip(request),
        user_agent=user_agent[:500],
        device_info=_summarize_device(user_agent)[:255],
    )
    session.add(row)
    session.commit()
    session.refresh(row)
    return row
def _create_auth_token(
    session: Session,
    *,
    request: Request,
    auth_type: str,
    subject_id: str,
    bot_id: Optional[str],
    auth_source: str,
) -> str:
    """Issue a new session token for a principal and return it in raw form.

    Steps, in order: enforce the per-principal concurrency cap, generate a
    random URL-safe token, record its audit row, then mirror it into the auth
    cache. The caller receives the raw token (e.g. to set as a cookie).
    """
    _ensure_auth_store_available()
    lifetime = _auth_token_ttl_seconds(session)
    cap = _auth_token_max_active(session)
    _enforce_token_limit(session, auth_type=auth_type, subject_id=subject_id, bot_id=bot_id, max_active=cap)
    token = secrets.token_urlsafe(32)
    audit_row = _create_audit_row(
        session,
        request=request,
        auth_type=auth_type,
        subject_id=subject_id,
        bot_id=bot_id,
        raw_token=token,
        expires_at=_utcnow() + timedelta(seconds=lifetime),
        auth_source=auth_source,
    )
    _persist_token_payload(session, row=audit_row, raw_token=token, ttl_seconds=lifetime)
    return token
def create_panel_token(session: Session, request: Request) -> str:
    """Log into the panel: revoke any token the request already carries, then issue a new one."""
    revoke_panel_token(session, request, reason="superseded")
    token = _create_auth_token(
        session,
        request=request,
        auth_type="panel",
        subject_id=PANEL_SUBJECT_ID,
        bot_id=None,
        auth_source="panel_password",
    )
    return token
def create_bot_token(session: Session, request: Request, bot_id: str) -> str:
    """Issue a bot-scoped token (subject id == bot id), superseding the request's current one.

    ``auth_source`` records whether the bot was password-protected at issue time.
    """
    bot_key = str(bot_id or "").strip()
    revoke_bot_token(session, request, bot_key, reason="superseded")
    source = _resolve_bot_auth_source(session, bot_key)
    return _create_auth_token(
        session,
        request=request,
        auth_type="bot",
        subject_id=bot_key,
        bot_id=bot_key,
        auth_source=source,
    )
def revoke_panel_token(session: Session, request: Request, reason: str = "logout") -> None:
    """Revoke the panel token carried by ``request`` (cookie or Bearer header), if any."""
    presented = _read_panel_token(request)
    _revoke_raw_token(session, presented, reason=reason)
def revoke_bot_token(session: Session, request: Request, bot_id: str, reason: str = "logout") -> None:
    """Revoke the token for ``bot_id`` carried by ``request`` (cookie or Bearer header), if any."""
    presented = _read_bot_token(request, bot_id)
    _revoke_raw_token(session, presented, reason=reason)
def _set_cookie(response: Response, request: Request, name: str, raw_token: str, max_age: int) -> None:
    """Set an HttpOnly, SameSite=Lax, path-"/" session cookie on ``response``.

    ``Secure`` is set only when the request URL's scheme is https.
    NOTE(review): behind a TLS-terminating proxy the scheme may read as "http"
    unless proxy headers are honoured upstream — confirm for the deployment.
    """
    is_https = str(request.url.scheme).lower() == "https"
    cookie_options = {
        "max_age": max_age,
        "httponly": True,
        "samesite": "lax",
        "secure": is_https,
        "path": "/",
    }
    response.set_cookie(name, raw_token, **cookie_options)
def set_panel_token_cookie(response: Response, request: Request, raw_token: str, session: Session) -> None:
    """Attach the panel token cookie; its max-age equals the configured token TTL."""
    lifetime = _auth_token_ttl_seconds(session)
    _set_cookie(response, request, PANEL_TOKEN_COOKIE, raw_token, lifetime)
def set_bot_token_cookie(response: Response, request: Request, bot_id: str, raw_token: str, session: Session) -> None:
    """Attach the per-bot token cookie; its max-age equals the configured token TTL."""
    cookie_name = _normalize_bot_cookie_name(bot_id)
    lifetime = _auth_token_ttl_seconds(session)
    _set_cookie(response, request, cookie_name, raw_token, lifetime)
def clear_panel_token_cookie(response: Response) -> None:
    """Expire the panel token cookie on the client (path "/", matching how it is set)."""
    cookie_name = PANEL_TOKEN_COOKIE
    response.delete_cookie(cookie_name, path="/")
def clear_bot_token_cookie(response: Response, bot_id: str) -> None:
    """Expire the per-bot token cookie on the client (path "/", matching how it is set)."""
    cookie_name = _normalize_bot_cookie_name(bot_id)
    response.delete_cookie(cookie_name, path="/")
def _missing_principal(expected_type: str, bot_id: Optional[str]) -> AuthPrincipal:
    """Unauthenticated principal for absent, unknown, revoked or expired tokens."""
    return AuthPrincipal(expected_type, "", bot_id, False, "missing")


def _resolve_from_audit_row(
    session: Session,
    *,
    token_hash: str,
    expected_type: str,
    bot_id: Optional[str],
) -> AuthPrincipal:
    """Authenticate ``token_hash`` using only its database audit row.

    Used when the cache has no payload, disagrees with the expected principal,
    or points at an audit row that is gone or revoked.
    """
    row = _active_token_row(session, token_hash=token_hash, expected_type=expected_type, bot_id=bot_id)
    if row is None:
        return _missing_principal(expected_type, bot_id)
    _touch_session(session, row)
    return AuthPrincipal(expected_type, row.subject_id, row.bot_id, True, f"{expected_type}_token", row.id)


def _cached_payload_expired(payload: dict[str, Any]) -> bool:
    """True when the payload's ``expires_at`` parses and is not in the future.

    A missing, malformed or incomparable (e.g. timezone-aware) value counts as
    "not expired" here. The caller still checks the audit row's own
    ``expires_at`` afterwards.
    """
    expires_at_raw = str(payload.get("expires_at") or "").strip()
    if not expires_at_raw:
        return False
    try:
        expires_at = datetime.fromisoformat(expires_at_raw.replace("Z", ""))
        return expires_at <= _utcnow()
    except (TypeError, ValueError):
        return False


def _resolve_token_auth(
    session: Session,
    *,
    raw_token: str,
    expected_type: str,
    bot_id: Optional[str] = None,
) -> AuthPrincipal:
    """Validate a raw panel/bot token and return the resulting principal.

    Resolution order:
      1. Blank token -> unauthenticated ("missing").
      2. No cached payload (or cache disabled) -> decide from the audit row.
      3. Cached payload for another type or bot -> decide from the audit row.
      4. Cached ``expires_at`` already passed -> revoke as "expired", unauthenticated.
      5. Cached ``audit_id`` row missing or revoked -> decide from the audit row
         found by digest. If that row is past ``expires_at`` -> revoke,
         unauthenticated.
      6. Otherwise authenticated with the cached identity. ``last_seen_at`` is
         touched (throttled).
    """
    token = _normalize_token(raw_token)
    normalized_bot_id = str(bot_id or "").strip() or None
    if not token:
        return _missing_principal(expected_type, normalized_bot_id)
    token_hash = _hash_session_token(token)
    payload = _active_token_payload(token_hash) if auth_cache.enabled else None
    if not payload:
        return _resolve_from_audit_row(
            session, token_hash=token_hash, expected_type=expected_type, bot_id=normalized_bot_id
        )
    auth_type, subject_id, payload_bot_id = _principal_from_payload(payload)
    if auth_type != expected_type or (expected_type == "bot" and payload_bot_id != normalized_bot_id):
        return _resolve_from_audit_row(
            session, token_hash=token_hash, expected_type=expected_type, bot_id=normalized_bot_id
        )
    # Only the parse/compare inside _cached_payload_expired is best-effort. A failure
    # while revoking must propagate instead of being swallowed and letting the request
    # fall through to the remaining checks.
    if _cached_payload_expired(payload):
        _revoke_token_hash(session, token_hash, reason="expired")
        return _missing_principal(expected_type, normalized_bot_id)
    row_id = int(payload.get("audit_id") or 0) or None
    if row_id is not None:
        row = session.get(AuthLoginLog, row_id)
        if row is None or row.revoked_at is not None:
            return _resolve_from_audit_row(
                session, token_hash=token_hash, expected_type=expected_type, bot_id=normalized_bot_id
            )
        if row.expires_at <= _utcnow():
            _revoke_token_hash(session, token_hash, reason="expired")
            return _missing_principal(expected_type, normalized_bot_id)
        _touch_session(session, row)
    # NOTE(review): without an audit_id, the cached payload alone authenticates the
    # request. Payloads written by _persist_token_payload should always carry one.
    return AuthPrincipal(expected_type, subject_id, payload_bot_id, True, f"{expected_type}_token", row_id)
def resolve_panel_request_auth(session: Session, request: Request) -> AuthPrincipal:
    """Authenticate an HTTP request for the panel.

    With no panel password configured, every request is accepted as the panel
    admin with source "unprotected". Otherwise the token from the panel cookie
    or Bearer header is validated.
    """
    if _is_panel_auth_enabled():
        return _resolve_token_auth(session, raw_token=_read_panel_token(request), expected_type="panel")
    return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
def resolve_bot_request_auth(session: Session, request: Request, bot_id: str) -> AuthPrincipal:
    """Authenticate an HTTP request against a bot-scoped token.

    A blank ``bot_id`` is rejected outright. Unlike the panel check, there is
    no "unprotected" shortcut here: a valid bot token is always required.
    """
    bot_key = str(bot_id or "").strip()
    if bot_key:
        return _resolve_token_auth(
            session,
            raw_token=_read_bot_token(request, bot_key),
            expected_type="bot",
            bot_id=bot_key,
        )
    return AuthPrincipal("bot", "", None, False, "missing")
def resolve_panel_websocket_auth(session: Session, websocket: WebSocket) -> AuthPrincipal:
    """WebSocket counterpart of ``resolve_panel_request_auth``.

    Reads the token from the handshake cookies or Bearer header.
    """
    if _is_panel_auth_enabled():
        return _resolve_token_auth(session, raw_token=_read_panel_token_ws(websocket), expected_type="panel")
    return AuthPrincipal("panel", PANEL_SUBJECT_ID, None, True, "unprotected")
def resolve_bot_websocket_auth(session: Session, websocket: WebSocket, bot_id: str) -> AuthPrincipal:
    """WebSocket counterpart of ``resolve_bot_request_auth``; a blank ``bot_id`` is rejected."""
    bot_key = str(bot_id or "").strip()
    if bot_key:
        return _resolve_token_auth(
            session,
            raw_token=_read_bot_token_ws(websocket, bot_key),
            expected_type="bot",
            bot_id=bot_key,
        )
    return AuthPrincipal("bot", "", None, False, "missing")