102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
|
|
import os
import warnings
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import jwt
import redis

from app.core.config import REDIS_CONFIG
|
|||
|
|
|
|||
|
|
# JWT configuration
# Placeholder secret used only when the JWT_SECRET_KEY environment variable is
# not set. It is public (it lives in the source), so tokens signed with it can
# be forged by anyone who has read this file.
_DEFAULT_JWT_SECRET_KEY = 'your-super-secret-key-change-in-production'
JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', _DEFAULT_JWT_SECRET_KEY)
if JWT_SECRET_KEY == _DEFAULT_JWT_SECRET_KEY:
    # Keep the fallback so existing development setups still start, but make
    # the insecure configuration visible instead of silently accepting it.
    warnings.warn(
        "JWT_SECRET_KEY is not set (or equals the built-in placeholder); "
        "tokens are being signed with a publicly known key. "
        "Set the JWT_SECRET_KEY environment variable in production.",
        RuntimeWarning,
        stacklevel=2,
    )
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
|
|||
|
|
|
|||
|
|
class JWTService:
    """Issue, verify, revoke and refresh JWT access tokens tracked in Redis.

    Each token issued for a payload that carries a truthy ``user_id`` is
    recorded under the Redis key ``token:{user_id}:{token}`` with a TTL equal
    to the token lifetime. ``verify_token`` only accepts a token while that key
    exists, so deleting the key revokes the token before its ``exp`` claim is
    reached.
    """

    def __init__(self):
        """Build the Redis client from ``REDIS_CONFIG``."""
        self.redis_client = redis.Redis(**REDIS_CONFIG)

    def create_access_token(self, data: Dict[str, Any]) -> str:
        """Create a signed JWT access token and register it in Redis.

        Args:
            data: Claims to embed in the token. The dict is copied, not
                mutated. ``exp`` and ``type`` are set by this method and
                override same-named entries in ``data``.

        Returns:
            The encoded JWT as a string.

        Note:
            The token is only stored in Redis when ``data`` has a truthy
            ``user_id``. ``verify_token`` rejects tokens that are not stored,
            so a token created without ``user_id`` will never verify.
        """
        to_encode = data.copy()
        # Timezone-aware "now": datetime.utcnow() is deprecated and returns a
        # naive value that is easy to misinterpret as local time.
        expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        to_encode.update({"exp": expire, "type": "access"})

        encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
        if isinstance(encoded_jwt, bytes):
            # PyJWT 1.x returns bytes; normalise so the Redis key below and the
            # declared return type are the same on every PyJWT version.
            encoded_jwt = encoded_jwt.decode("utf-8")

        # Store the token in Redis so it can be managed and revoked.
        user_id = data.get("user_id")
        if user_id:
            self.redis_client.setex(
                f"token:{user_id}:{encoded_jwt}",
                ACCESS_TOKEN_EXPIRE_MINUTES * 60,  # Redis expects seconds
                "active",
            )

        return encoded_jwt

    def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Validate a JWT access token.

        A token is accepted only if its signature and expiry are valid, its
        ``type`` claim is ``"access"``, it carries a truthy ``user_id``, and
        its Redis key still exists (i.e. it has not been revoked).

        Args:
            token: The encoded JWT.

        Returns:
            The decoded payload, or ``None`` if any check fails.
        """
        try:
            # Decode and validate signature / expiry.
            payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])

            # Only access tokens are accepted here.
            if payload.get("type") != "access":
                return None

            user_id = payload.get("user_id")
            if not user_id:
                return None

            # The token must still be present in Redis (not revoked).
            redis_key = f"token:{user_id}:{token}"
            if not self.redis_client.exists(redis_key):
                return None

            return payload
        except jwt.InvalidTokenError:
            # Also covers jwt.ExpiredSignatureError, which is a subclass.
            return None
        except Exception:
            # Fail closed: any other failure (e.g. Redis unreachable) must
            # never cause a token to be accepted.
            return None

    def revoke_token(self, token: str, user_id: int) -> bool:
        """Revoke a single token by deleting its Redis key.

        Args:
            token: The encoded JWT to revoke.
            user_id: The user the token was issued for.

        Returns:
            ``True`` if a key was deleted, ``False`` if no such key existed or
            the Redis operation failed.
        """
        redis_key = f"token:{user_id}:{token}"
        try:
            return self.redis_client.delete(redis_key) > 0
        except redis.RedisError:
            # Best effort: report failure instead of raising, but do not
            # swallow KeyboardInterrupt / SystemExit as a bare except would.
            return False

    def revoke_all_user_tokens(self, user_id: int) -> int:
        """Revoke every stored token of a user.

        Args:
            user_id: The user whose tokens are revoked.

        Returns:
            The number of Redis keys deleted; ``0`` if none matched or the
            Redis operation failed.
        """
        pattern = f"token:{user_id}:*"
        try:
            # SCAN walks the keyspace incrementally; KEYS would block the
            # Redis server while it scans every key in one command.
            keys = list(self.redis_client.scan_iter(match=pattern))
            if not keys:
                return 0
            return self.redis_client.delete(*keys)
        except redis.RedisError:
            return 0

    def refresh_token(self, token: str) -> Optional[str]:
        """Exchange a valid token for a new one (optional feature).

        The old token is revoked and a new token is issued. Only the
        ``user_id``, ``username`` and ``caption`` claims are carried over.

        Args:
            token: The currently valid encoded JWT.

        Returns:
            The new encoded JWT, or ``None`` if ``token`` does not verify.
        """
        payload = self.verify_token(token)
        if not payload:
            return None

        # Revoke the old token.
        user_id = payload.get("user_id")
        self.revoke_token(token, user_id)

        # Issue the replacement token.
        new_data = {
            "user_id": user_id,
            "username": payload.get("username"),
            "caption": payload.get("caption"),
        }
        return self.create_access_token(new_data)
|
|||
|
|
|
|||
|
|
# Global instance: a single module-level JWTService, constructed at import
# time, which also builds its Redis client from REDIS_CONFIG.
jwt_service = JWTService()
|