215 lines
6.8 KiB
Python
215 lines
6.8 KiB
Python
|
|
"""
Authentication dependencies: resolve the currently logged-in user.
"""
|
|||
|
|
from fastapi import Depends, HTTPException, status, Request
|
|||
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from typing import Optional
|
|||
|
|
from app.core.database import get_db
|
|||
|
|
from app.core.security import decode_access_token
|
|||
|
|
from app.core.redis_client import TokenCache
|
|||
|
|
from app.models.user import User
|
|||
|
|
import logging
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# HTTP Bearer authentication schemes used as FastAPI dependencies.
# `security` is the default scheme; `security_optional` is created with
# auto_error=False so that callers can receive "no credentials" and fall
# back to another token source instead of failing immediately.
security = HTTPBearer()
security_optional = HTTPBearer(auto_error=False)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that always carries the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> User:
    """Resolve the currently logged-in user (FastAPI dependency).

    The bearer token is validated in this order:

    1. it must be known to ``TokenCache`` (Redis-backed session lookup);
    2. it must decode via ``decode_access_token``;
    3. its ``sub`` claim must be an integer user id;
    4. that id must equal the id stored in the token cache;
    5. the user row must exist and have ``status == 1``.

    Args:
        request: Current request. The raw token is stored on
            ``request.state.token`` so later handlers (e.g. logout) can use it.
        credentials: Bearer credentials extracted from the Authorization header.
        db: Async database session used to load the user.

    Returns:
        The ``User`` row matching the token's subject.

    Raises:
        HTTPException: 401 when the token is unknown/expired, cannot be
            decoded, has a missing or malformed ``sub``, disagrees with the
            cached user id, or references a user that does not exist;
            403 when the user exists but ``status != 1``.
    """
    token = credentials.credentials
    # Only the first 20 characters are logged, never the whole token.
    logger.info("Received token: %s...", token[:20])

    # Keep the token on the request state so logout can reach it later.
    request.state.token = token

    # Step 1: the token must still be present in the token cache.
    user_id_from_redis = await TokenCache.get_user_id(token)
    if user_id_from_redis is None:
        logger.error("Token not found in Redis or expired")
        raise _unauthorized("登录已过期,请重新登录")

    # Step 2: decode the JWT to verify its integrity.
    payload = decode_access_token(token)
    logger.info("Decoded payload: %s", payload)
    if payload is None:
        logger.error("Token decode failed: payload is None")
        raise _unauthorized("无效的认证凭证")

    # Step 3: the subject claim carries the user id as a string.
    user_id_str = payload.get("sub")
    logger.info("Extracted user_id (string): %s", user_id_str)
    if user_id_str is None:
        logger.error("user_id is None in payload")
        raise _unauthorized("无效的认证凭证")

    try:
        user_id = int(user_id_str)
    except (ValueError, TypeError):
        logger.error("Invalid user_id format: %s", user_id_str)
        raise _unauthorized("无效的认证凭证")

    # Step 4: the cached id and the JWT subject must agree.
    # NOTE(review): this compares an int with whatever TokenCache returns;
    # presumably that is an int as well - confirm against TokenCache.
    if user_id != user_id_from_redis:
        logger.error(
            "User ID mismatch: JWT=%s, Redis=%s", user_id, user_id_from_redis
        )
        raise _unauthorized("无效的认证凭证")

    # Step 5: load the user and make sure the account is enabled.
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    if user is None:
        logger.error("User not found for user_id: %s", user_id)
        raise _unauthorized("用户不存在")

    if user.status != 1:
        logger.error("User %s is disabled", user_id)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用",
        )

    logger.info("User authenticated successfully: %s", user.username)
    return user
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the current active user.

    This dependency performs no checks of its own: it hands back the user
    produced by ``get_current_user``, which already rejects accounts whose
    ``status`` is not 1.
    """
    return current_user
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def get_user_from_token_or_query(
    request: Request,
    token: Optional[str] = None,  # taken from the query string
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional),
    db: AsyncSession = Depends(get_db)
) -> User:
    """Resolve the current user from the Authorization header or a query token.

    Intended for resources such as images, where the client may not be able
    to set headers. The Authorization header takes precedence; the ``token``
    query parameter is used only when no header credentials are present.

    Once the token source is chosen, validation is delegated to
    ``get_current_user`` so both entry points share one implementation of
    the cache lookup, JWT decoding, id cross-check and user status check.

    Args:
        request: Current request; the chosen token ends up on
            ``request.state.token`` (set by ``get_current_user``).
        token: Optional token supplied as a query parameter.
        credentials: Optional bearer credentials from the Authorization header.
        db: Async database session used to load the user.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 when neither a header nor a query token is
            supplied, plus any 401/403 raised by ``get_current_user``.
    """
    if not credentials:
        if not token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="未提供认证凭证",
                headers={"WWW-Authenticate": "Bearer"},
            )
        # Wrap the query-string token so it can go through the same
        # validation path as header credentials.
        credentials = HTTPAuthorizationCredentials(
            scheme="Bearer", credentials=token
        )

    # Called directly (not through Depends), so every argument is explicit.
    return await get_current_user(request, credentials, db)
|