nex_docus/backend/app/core/deps.py

215 lines
6.8 KiB
Python
Raw Normal View History

2025-12-20 11:18:59 +00:00
"""
认证依赖获取当前登录用户
"""
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
from app.core.database import get_db
from app.core.security import decode_access_token
from app.core.redis_client import TokenCache
from app.models.user import User
import logging
logger: logging.Logger = logging.getLogger(__name__)

# HTTP Bearer authentication schemes.
# Lenient variant: yields None instead of raising when no Bearer header is
# sent, so a dependency can fall back to another token source.
security_optional = HTTPBearer(auto_error=False)
# Strict variant (auto_error=True is FastAPI's default): a missing or
# malformed Bearer header is rejected by FastAPI before the dependency runs.
security = HTTPBearer(auto_error=True)
def _unauthorized(detail: str = "无效的认证凭证") -> HTTPException:
    """Build a 401 error that always carries the Bearer challenge header.

    RFC 7235 requires every 401 response to include ``WWW-Authenticate``,
    so all authentication failures in this module are raised through here.

    Args:
        detail: User-facing error message (defaults to "invalid credentials").

    Returns:
        An ``HTTPException`` ready to be raised.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def _authenticate_token(
    request: Request,
    token: str,
    db: AsyncSession,
) -> User:
    """Validate a raw token and return the active user it belongs to.

    Shared by every authentication dependency in this module so the
    validation rules live in exactly one place. A token is accepted only if:

    1. it is still registered in Redis via ``TokenCache``;
    2. it decodes successfully with ``decode_access_token``;
    3. its ``sub`` claim parses as an integer user id;
    4. that id equals the user id stored in Redis for the token;
    5. a ``User`` with that id exists and has ``status == 1``.

    Args:
        request: Current request. The raw token is stored on
            ``request.state.token`` so the logout flow can read it.
        token: Raw bearer token string.
        db: Async SQLAlchemy session used to load the user.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            is expired, unknown, malformed, or refers to a missing user;
            403 when the user is disabled.
    """
    # Save the token on the request state for use by logout.
    request.state.token = token

    # Check Redis first. Presumably this lets logout/expiry revoke a token
    # server-side even while its JWT signature is still valid.
    user_id_from_redis = await TokenCache.get_user_id(token)
    if user_id_from_redis is None:
        logger.warning("Token not found in Redis or expired")
        raise _unauthorized("登录已过期,请重新登录")

    # Decode the JWT to verify its integrity. The token and its payload are
    # deliberately not logged: they are credentials.
    payload = decode_access_token(token)
    if payload is None:
        logger.warning("Token decode failed: payload is None")
        raise _unauthorized()

    user_id_str = payload.get("sub")
    if user_id_str is None:
        logger.warning("user_id is None in payload")
        raise _unauthorized()

    # The ``sub`` claim carries the user id as a string; convert to int.
    try:
        user_id = int(user_id_str)
    except (ValueError, TypeError):
        logger.warning("Invalid user_id format: %r", user_id_str)
        raise _unauthorized() from None

    # The Redis entry and the JWT must agree on the token's owner.
    # NOTE(review): assumes TokenCache.get_user_id returns an int; a str
    # value would never compare equal to ``user_id`` here. Confirm.
    if user_id != user_id_from_redis:
        logger.warning(
            "User ID mismatch: JWT=%s, Redis=%s", user_id, user_id_from_redis
        )
        raise _unauthorized()

    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        logger.warning("User not found for user_id: %s", user_id)
        raise _unauthorized("用户不存在")

    # Any status other than 1 is treated as a disabled account.
    if user.status != 1:
        logger.warning("User %s is disabled", user_id)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    logger.debug("User authenticated successfully: %s", user.username)
    return user


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the logged-in user from the Bearer header.

    The header is mandatory (``security`` is the strict ``HTTPBearer``).

    Args:
        request: Current request; ``request.state.token`` is set for logout.
        credentials: Parsed ``Authorization: Bearer <token>`` header.
        db: Async database session.

    Returns:
        The authenticated, enabled ``User``.

    Raises:
        HTTPException: 401 for invalid/expired credentials, 403 if disabled.
    """
    return await _authenticate_token(request, credentials.credentials, db)


async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """FastAPI dependency: return the current active user.

    ``get_current_user`` already rejects users whose status is not 1, so
    this is a pass-through kept for callers that depend on this name.
    """
    return current_user


async def get_user_from_token_or_query(
    request: Request,
    token: Optional[str] = None,  # token supplied as a query parameter
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db),
) -> User:
    """FastAPI dependency: resolve the user from the header or a query param.

    Intended for resource URLs (e.g. images) where the client may not be
    able to set headers. The ``Authorization: Bearer`` header takes
    precedence; the ``token`` query parameter is used only as a fallback.
    Note that tokens in query strings can end up in access logs and Referer
    headers, so prefer the header where the client allows it.

    Args:
        request: Current request; ``request.state.token`` is set for logout.
        token: Optional token taken from the query string.
        credentials: Optional parsed Bearer header (None when absent).
        db: Async database session.

    Returns:
        The authenticated, enabled ``User``.

    Raises:
        HTTPException: 401 when no token is supplied or it is invalid/expired;
            403 if the user is disabled.
    """
    # Prefer the Authorization header; fall back to the query parameter.
    if credentials:
        token_str = credentials.credentials
    elif token:
        token_str = token
    else:
        raise _unauthorized("未提供认证凭证")

    return await _authenticate_token(request, token_str, db)