153 lines
4.4 KiB
Python
153 lines
4.4 KiB
Python
|
|
"""
|
|||
|
|
Redis 客户端管理
|
|||
|
|
"""
|
|||
|
|
import redis.asyncio as aioredis
|
|||
|
|
from typing import Optional
|
|||
|
|
from app.core.config import settings
|
|||
|
|
import logging
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# Module-wide Redis client instance; None until init_redis() assigns a client.
redis_client: Optional[aioredis.Redis] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def init_redis(max_connections: int = 10):
    """Create the module-wide Redis client and verify it with a PING.

    The client is built from ``settings.REDIS_URL`` with ``encoding="utf-8"``
    and ``decode_responses=True``. The global ``redis_client`` is only replaced
    once the PING succeeds. On failure, the half-created client is closed so
    its connection pool is not leaked, and the error is re-raised.

    :param max_connections: maximum number of pooled connections (default 10).
    :raises Exception: whatever the client raised while connecting or pinging.
    """
    global redis_client
    client = None
    try:
        client = await aioredis.from_url(
            settings.REDIS_URL,
            encoding="utf-8",
            decode_responses=True,
            max_connections=max_connections,
        )
        # Test the connection before publishing the client to the module.
        await client.ping()
    except Exception as e:
        logger.error("Redis connection failed: %s", e)
        if client is not None:
            # Release the unusable client's pool; a failure while closing must
            # not mask the original connection error re-raised below.
            try:
                await client.close()
            except Exception:
                logger.warning("Error while closing failed Redis client", exc_info=True)
        raise
    redis_client = client
    logger.info("Redis connected successfully")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def close_redis():
    """Close the module-wide Redis client, if one exists.

    The global is reset to ``None`` before closing. Even if ``close()`` raises,
    later ``get_redis()`` calls therefore never return a closed client, and
    calling ``close_redis()`` again is a no-op.
    """
    global redis_client
    client = redis_client
    if client is None:
        return
    redis_client = None
    await client.close()
    logger.info("Redis connection closed")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_redis() -> Optional[aioredis.Redis]:
    """Return the module-wide Redis client.

    :return: the client assigned by ``init_redis()``, or ``None`` if
        ``init_redis()`` has not assigned one yet. Callers that use the
        result directly should make sure ``init_redis()`` ran first.
    """
    return redis_client
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Token cache operations
class TokenCache:
    """Redis-backed token cache.

    Two kinds of keys are maintained:

    * ``token:<token>`` -> user id (as a string), with a per-token TTL;
    * ``user_tokens:<user_id>`` -> set of that user's tokens (multi-device
      login), so all of a user's tokens can be revoked at once.

    The per-user set's TTL is only ever raised, never lowered, so the set
    outlives every token it lists. Otherwise saving or extending a
    short-lived token could expire the set early, and
    ``delete_user_all_tokens`` would miss tokens that are still valid.
    """

    TOKEN_PREFIX = "token:"
    USER_TOKEN_PREFIX = "user_tokens:"

    @staticmethod
    async def _extend_ttl(redis, key: str, expire_seconds: int) -> None:
        """Set ``key``'s TTL to ``expire_seconds`` unless it already lives longer."""
        # TTL returns -1 for a key with no expiry and -2 for a missing key;
        # both are below any positive expire_seconds, so EXPIRE is applied.
        current_ttl = await redis.ttl(key)
        if current_ttl < expire_seconds:
            await redis.expire(key, expire_seconds)

    @staticmethod
    async def save_token(user_id: int, token: str, expire_seconds: int = 86400):
        """
        Save a token to Redis.

        :param user_id: user ID
        :param token: JWT token
        :param expire_seconds: TTL in seconds, default 24 hours
        """
        redis = get_redis()

        # token -> user_id mapping
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        await redis.setex(token_key, expire_seconds, str(user_id))

        # user_id -> set of tokens (multi-device login); the set must not
        # expire before any token it contains.
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        await redis.sadd(user_tokens_key, token)
        await TokenCache._extend_ttl(redis, user_tokens_key, expire_seconds)

        logger.info(f"Token saved for user {user_id}, expires in {expire_seconds}s")

    @staticmethod
    async def get_user_id(token: str) -> Optional[int]:
        """
        Look up the user ID for a token.

        :param token: JWT token
        :return: the user ID, or None if the token is unknown or expired
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)

        if user_id_str:
            return int(user_id_str)
        return None

    @staticmethod
    async def delete_token(token: str):
        """
        Delete a single token.

        If the token key still exists, the token is also removed from its
        user's token set.

        :param token: JWT token
        """
        redis = get_redis()

        # Look up the owning user_id
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)

        if user_id_str:
            user_id = int(user_id_str)

            # Remove from the user's token set
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
            await redis.srem(user_tokens_key, token)

        # Delete the token itself
        await redis.delete(token_key)
        # NOTE(review): the first 20 characters of a JWT are usually its
        # base64url header, which tends to be identical across tokens, so
        # this prefix may not identify the token — confirm it is useful.
        logger.info(f"Token deleted: {token[:20]}...")

    @staticmethod
    async def delete_user_all_tokens(user_id: int):
        """
        Delete all of a user's tokens (used for forced logout).

        :param user_id: user ID
        """
        redis = get_redis()
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"

        # Collect every token key for the user
        tokens = await redis.smembers(user_tokens_key)
        token_keys = [f"{TokenCache.TOKEN_PREFIX}{token}" for token in tokens]

        # Delete all token keys and the set itself in one DEL round trip
        await redis.delete(*token_keys, user_tokens_key)
        logger.info(f"All tokens deleted for user {user_id}")

    @staticmethod
    async def extend_token(token: str, expire_seconds: int = 86400):
        """
        Extend a token's lifetime (used for token refresh).

        :param token: JWT token
        :param expire_seconds: new TTL in seconds for the token
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"

        # Extend the token's TTL (EXPIRE is a no-op if the key is gone)
        await redis.expire(token_key, expire_seconds)

        # Keep the user's token set alive at least as long as this token,
        # without shortening it below other tokens' lifetimes.
        user_id_str = await redis.get(token_key)
        if user_id_str:
            user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id_str}"
            await TokenCache._extend_ttl(redis, user_tokens_key, expire_seconds)

        logger.info(f"Token extended: {token[:20]}...")
|