nex_docus/backend/app/core/redis_client.py

153 lines
4.4 KiB
Python
Raw Normal View History

2025-12-20 11:18:59 +00:00
"""
Redis 客户端管理
"""
import redis.asyncio as aioredis
from typing import Optional
from app.core.config import settings
import logging
logger = logging.getLogger(__name__)

# Process-wide Redis client instance. It stays None until init_redis()
# assigns it, hence the Optional annotation.
redis_client: Optional[aioredis.Redis] = None
async def init_redis():
    """Create the shared Redis client and verify connectivity with PING.

    On success the client is stored in the module-level ``redis_client``
    (which ``get_redis()`` returns). On failure the error is logged and
    re-raised. Any client created before the failure is closed, and the
    global is left unchanged, so a half-initialised client is never published.
    """
    global redis_client
    client = None
    try:
        client = await aioredis.from_url(
            settings.REDIS_URL,
            encoding="utf-8",
            decode_responses=True,
            max_connections=10,
        )
        # Fail fast at startup if the server cannot be reached.
        await client.ping()
    except Exception as e:
        logger.error("Redis connection failed: %s", e)
        if client is not None:
            try:
                # Release the connection pool of the unusable client.
                # redis-py >= 5.0.1 names this aclose(); older versions only have close().
                await (getattr(client, "aclose", None) or client.close)()
            except Exception:
                # Do not let a cleanup failure mask the original error re-raised below.
                logger.warning("Failed to close Redis client after init error", exc_info=True)
        raise
    redis_client = client
    logger.info("Redis connected successfully")
async def close_redis():
    """Close the shared Redis client, if there is one, and clear the global.

    Idempotent: after the first call ``redis_client`` is None, so later
    calls do nothing and ``get_redis()`` no longer hands out a closed client.
    """
    global redis_client
    if redis_client is None:
        return
    # Unpublish the client before awaiting, so nothing picks it up mid-close.
    client, redis_client = redis_client, None
    # redis-py >= 5.0.1 deprecates close() in favour of aclose(); support both.
    await (getattr(client, "aclose", None) or client.close)()
    logger.info("Redis connection closed")
def get_redis() -> Optional[aioredis.Redis]:
    """Return the shared Redis client.

    :return: the client assigned by ``init_redis()``, or None if no client
        has been assigned yet.
    """
    return redis_client
# Token cache operations
class TokenCache:
    """Redis-backed store of issued JWTs.

    Two kinds of keys are maintained:

    * ``token:<jwt>`` -> user ID (as a string), with the token's own TTL;
    * ``user_tokens:<user_id>`` -> set of that user's tokens. This lets every
      session of one user (multi-device login) be revoked at once.

    Every method uses the client returned by ``get_redis()``, so
    ``init_redis()`` must have completed first.
    """

    TOKEN_PREFIX = "token:"
    USER_TOKEN_PREFIX = "user_tokens:"

    @staticmethod
    async def _keep_ttl_at_least(redis, key: str, expire_seconds: int) -> None:
        """Give *key* a TTL of at least *expire_seconds*, never shortening it.

        Redis ``TTL`` returns -1 for a key without expiry and -2 for a
        missing key. Both are below any positive *expire_seconds*, so such
        keys receive the new TTL.
        """
        if await redis.ttl(key) < expire_seconds:
            await redis.expire(key, expire_seconds)

    @staticmethod
    async def save_token(user_id: int, token: str, expire_seconds: int = 86400):
        """Store *token* in Redis and index it under *user_id*.

        :param user_id: user ID
        :param token: JWT token
        :param expire_seconds: lifetime in seconds; defaults to 86400 (24 hours)
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        # Write the token -> user_id mapping and the set membership in one
        # MULTI/EXEC. This way a token can never exist without being indexed,
        # and therefore can always be revoked by delete_user_all_tokens.
        async with redis.pipeline(transaction=True) as pipe:
            pipe.setex(token_key, expire_seconds, str(user_id))
            pipe.sadd(user_tokens_key, token)
            await pipe.execute()
        # The set must outlive its longest-lived member. A shorter-lived login
        # must not cut the TTL of a set that still indexes longer-lived tokens.
        await TokenCache._keep_ttl_at_least(redis, user_tokens_key, expire_seconds)
        logger.info("Token saved for user %s, expires in %ss", user_id, expire_seconds)

    @staticmethod
    async def get_user_id(token: str) -> Optional[int]:
        """Look up the user ID a token was saved for.

        :param token: JWT token
        :return: the user ID, or None if the token is unknown or has expired
        """
        user_id_str = await get_redis().get(f"{TokenCache.TOKEN_PREFIX}{token}")
        return int(user_id_str) if user_id_str else None

    @staticmethod
    async def delete_token(token: str):
        """Revoke a single token and remove it from its user's token set.

        :param token: JWT token
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        user_id_str = await redis.get(token_key)
        # Remove the set membership and the token key together.
        async with redis.pipeline(transaction=True) as pipe:
            if user_id_str:
                user_id = int(user_id_str)
                pipe.srem(f"{TokenCache.USER_TOKEN_PREFIX}{user_id}", token)
            pipe.delete(token_key)
            await pipe.execute()
        logger.info("Token deleted: %s...", token[:20])

    @staticmethod
    async def delete_user_all_tokens(user_id: int):
        """Revoke every token of a user (used for forced logout).

        :param user_id: user ID
        """
        redis = get_redis()
        user_tokens_key = f"{TokenCache.USER_TOKEN_PREFIX}{user_id}"
        tokens = await redis.smembers(user_tokens_key)
        # One multi-key DEL for all token keys plus the set itself, instead of
        # one round-trip per token.
        await redis.delete(
            *(f"{TokenCache.TOKEN_PREFIX}{t}" for t in tokens), user_tokens_key
        )
        logger.info("All tokens deleted for user %s", user_id)

    @staticmethod
    async def extend_token(token: str, expire_seconds: int = 86400) -> bool:
        """Reset a token's lifetime (used on token refresh).

        :param token: JWT token
        :param expire_seconds: new lifetime in seconds, counted from now
        :return: True if the token existed and was extended, False otherwise
        """
        redis = get_redis()
        token_key = f"{TokenCache.TOKEN_PREFIX}{token}"
        # EXPIRE is falsy when the key does not exist (unknown or expired token).
        if not await redis.expire(token_key, expire_seconds):
            return False
        user_id_str = await redis.get(token_key)
        if user_id_str:
            # Keep the user's token set alive at least as long as this token,
            # without shortening it below other, longer-lived tokens.
            await TokenCache._keep_ttl_at_least(
                redis, f"{TokenCache.USER_TOKEN_PREFIX}{user_id_str}", expire_seconds
            )
        logger.info("Token extended: %s...", token[:20])
        return True