# imetting_backend/app/services/voiceprint_service.py
#
# 218 lines
# 7.3 KiB
# Python
# Raw Permalink Normal View History
#
# 2025-10-31 06:54:54 +00:00
"""
Voiceprint service - handles collection, storage and verification of user voiceprints
"""
import os
import json
import wave
from datetime import datetime
from typing import Optional, Dict
from pathlib import Path
from app.core.database import get_db_connection
import app.core.config as config_module
class VoiceprintService:
    """Voiceprint service: synchronous handling of user voiceprint records.

    Public methods read, upsert and delete rows of the ``user_voiceprint``
    table through ``get_db_connection()``. The pure, database-free parts of
    that logic live in small private static helpers so they can be exercised
    in isolation.
    """

    def __init__(self):
        # Taken from the project configuration; no method of this class
        # reads it, it is only exposed as an attribute.
        self.voiceprint_dir = config_module.VOICEPRINT_DIR

    def get_user_voiceprint_status(self, user_id: int) -> Dict:
        """Return the voiceprint status of a user.

        Args:
            user_id: ID of the user to look up.

        Returns:
            Dict with the keys ``has_voiceprint``, ``vp_id``, ``file_path``,
            ``duration_seconds`` and ``collected_at``. When the user has no
            voiceprint row, ``has_voiceprint`` is False and every other
            value is None.

        Raises:
            Exception: any error raised while querying is printed and
                re-raised unchanged.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                query = """
                    SELECT vp_id, user_id, file_path, file_size, duration_seconds, collected_at, updated_at
                    FROM user_voiceprint
                    WHERE user_id = %s
                """
                cursor.execute(query, (user_id,))
                return self._status_from_row(cursor.fetchone())
        except Exception as e:
            print(f"获取声纹状态错误: {e}")
            raise

    @staticmethod
    def _status_from_row(row: Optional[Dict]) -> Dict:
        """Convert a ``user_voiceprint`` row (or None) into the status dict."""
        if not row:
            return {
                "has_voiceprint": False,
                "vp_id": None,
                "file_path": None,
                "duration_seconds": None,
                "collected_at": None,
            }

        duration = row['duration_seconds']
        collected_at = row['collected_at']
        return {
            "has_voiceprint": True,
            "vp_id": row['vp_id'],
            "file_path": row['file_path'],
            # Compare against None explicitly: a stored duration of 0 is a
            # real value and must not be reported as "missing".
            "duration_seconds": float(duration) if duration is not None else None,
            "collected_at": collected_at.isoformat() if collected_at else None,
        }

    def save_voiceprint(self, user_id: int, audio_file_path: str, file_size: int) -> Dict:
        """Save a voiceprint file record and its extracted feature vector.

        Updates the user's existing ``user_voiceprint`` row when there is
        one, otherwise inserts a new row.

        Args:
            user_id: ID of the user.
            audio_file_path: Path of the audio file; it must be located
                under ``config_module.BASE_DIR`` because the path is stored
                relative to that directory.
            file_size: Size of the file, stored as given.

        Returns:
            Dict with ``vp_id``, ``user_id``, ``file_path`` (relative),
            ``file_size``, ``duration_seconds`` and ``has_vector``.
            ``has_vector`` is True exactly when a vector was written to the
            database.

        Raises:
            Exception: any error (including ``ValueError`` when the path is
                not under ``BASE_DIR``) is printed and re-raised unchanged.
        """
        try:
            # 1. Audio duration (falls back to a default on unreadable files).
            duration = self._get_audio_duration(audio_file_path)

            # 2. Voiceprint vector (FunASR integration is still a stub).
            vector_data = self._extract_voiceprint_vector(audio_file_path)
            vector_json = self._serialize_vector(vector_data)

            # Resolve the stored path before touching the database so an
            # invalid path fails without opening a connection.
            relative_path = str(Path(audio_file_path).relative_to(config_module.BASE_DIR))

            # 3. Persist: update the existing row or insert a new one.
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute("SELECT vp_id FROM user_voiceprint WHERE user_id = %s", (user_id,))
                existing = cursor.fetchone()

                if existing:
                    update_query = """
                        UPDATE user_voiceprint
                        SET file_path = %s, file_size = %s, duration_seconds = %s,
                        vector_data = %s, updated_at = NOW()
                        WHERE user_id = %s
                    """
                    cursor.execute(update_query, (
                        relative_path, file_size, duration,
                        vector_json,
                        user_id,
                    ))
                    vp_id = existing['vp_id']
                else:
                    insert_query = """
                        INSERT INTO user_voiceprint
                        (user_id, file_path, file_size, duration_seconds, vector_data, collected_at, updated_at)
                        VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
                    """
                    cursor.execute(insert_query, (
                        user_id, relative_path, file_size, duration,
                        vector_json,
                    ))
                    vp_id = cursor.lastrowid

                connection.commit()

            return {
                "vp_id": vp_id,
                "user_id": user_id,
                "file_path": relative_path,
                "file_size": file_size,
                "duration_seconds": duration,
                # Derived from what was actually stored, so the flag and the
                # database column can never disagree (e.g. for an empty list).
                "has_vector": vector_json is not None,
            }
        except Exception as e:
            print(f"保存声纹错误: {e}")
            raise

    @staticmethod
    def _serialize_vector(vector_data: Optional[list]) -> Optional[str]:
        """Return the JSON text stored for a vector, or None for a missing/empty one."""
        if not vector_data:
            return None
        return json.dumps(vector_data)

    def delete_voiceprint(self, user_id: int) -> bool:
        """Delete a user's voiceprint row and its audio file.

        The database row is deleted and committed first; the file under
        ``config_module.BASE_DIR`` is removed afterwards. A row without a
        usable file path, or a file that is already gone, does not prevent
        the deletion from succeeding.

        Args:
            user_id: ID of the user.

        Returns:
            True when a row existed and was deleted, False when the user
            had no voiceprint row.

        Raises:
            Exception: any error is printed and re-raised unchanged. An
                error while removing the file is raised after the row has
                already been deleted.
        """
        try:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute("SELECT file_path FROM user_voiceprint WHERE user_id = %s", (user_id,))
                voiceprint = cursor.fetchone()
                if not voiceprint:
                    return False

                relative_path = self._normalize_relative_path(voiceprint['file_path'])

                cursor.execute("DELETE FROM user_voiceprint WHERE user_id = %s", (user_id,))
                connection.commit()

            # Only touch the filesystem when the row named a file; an empty
            # path would otherwise resolve to BASE_DIR itself.
            if relative_path:
                file_path = config_module.BASE_DIR / relative_path
                try:
                    os.remove(file_path)
                except FileNotFoundError:
                    # Already gone: nothing left to clean up.
                    pass
            return True
        except Exception as e:
            print(f"删除声纹错误: {e}")
            raise

    @staticmethod
    def _normalize_relative_path(stored_path: Optional[str]) -> Optional[str]:
        """Strip leading slashes from a stored path; return None when nothing remains."""
        if not stored_path:
            return None
        # Leading slashes are removed so that joining onto BASE_DIR stays
        # inside BASE_DIR instead of replacing it with an absolute path.
        return stored_path.lstrip('/') or None

    def _get_audio_duration(self, audio_file_path: str) -> float:
        """Return the duration of a WAV file in seconds, rounded to 2 decimals.

        Args:
            audio_file_path: Path of the audio file, opened with ``wave``.

        Returns:
            The duration in seconds, or the default ``10.0`` when the file
            cannot be read as WAV (the error is printed, not raised).
        """
        try:
            with wave.open(audio_file_path, 'rb') as wav_file:
                frame_count = wav_file.getnframes()
                frame_rate = wav_file.getframerate()
                return round(frame_count / float(frame_rate), 2)
        except Exception as e:
            # Deliberate best effort: saving a voiceprint must not fail just
            # because the duration could not be determined.
            print(f"获取音频时长错误: {e}")
            return 10.0  # default: 10 seconds

    def _extract_voiceprint_vector(self, audio_file_path: str) -> Optional[list]:
        """Extract the voiceprint feature vector (FunASR integration pending).

        Args:
            audio_file_path: Path of the audio file.

        Returns:
            Currently always None. Once integrated this is meant to return
            the embedding as a list of floats, e.g. ``[0.123, -0.456, ...]``.
        """
        # TODO: integrate FunASR's speaker-verification model
        # (speech_campplus_sv_zh-cn_16k-common) and return its
        # 192-dimensional embedding vector.
        print(f"[TODO] 调用FunASR提取声纹向量: {audio_file_path}")
        # Return None until the FunASR integration exists.
        return None
# Create the module-level shared instance (constructed when this module is imported)
voiceprint_service = VoiceprintService()