nex_basse/backend/app/services/meeting_service.py

728 lines
33 KiB
Python
Raw Normal View History

2026-02-25 08:48:31 +00:00
from sqlalchemy.orm import Session
2026-03-02 10:26:22 +00:00
from app.core.db import SessionLocal
from typing import Callable, Awaitable, Optional, List
from app.models import Meeting, SummarizeTask, PromptTemplate, AIModel, TranscriptSegment, TranscriptTask, MeetingAudio, Hotword
2026-02-25 08:48:31 +00:00
from app.services.llm_service import LLMService
import uuid
import json
import logging
import math
2026-03-02 10:26:22 +00:00
import httpx
import asyncio
2026-02-25 08:48:31 +00:00
from datetime import datetime
from app.core.redis import redis_client
2026-03-02 10:26:22 +00:00
from app.core.config import get_settings
from pathlib import Path
2026-02-25 08:48:31 +00:00
logger = logging.getLogger(__name__)
2026-03-02 10:26:22 +00:00
settings = get_settings()
2026-02-25 08:48:31 +00:00
class MeetingService:
2026-03-02 10:26:22 +00:00
@staticmethod
async def create_transcript_task(
    db: Session,
    meeting_id: int,
    model_id: int,
    language: str = "auto"
):
    """Create a pending transcription task for a meeting's latest audio.

    Validates that the meeting exists and has at least one uploaded audio
    file, inserts a new ``TranscriptTask`` row, flips the meeting status
    to ``"transcribing"``, commits, and returns the new task so the API
    layer can schedule background processing.

    Raises a generic ``Exception`` when the meeting or its audio is missing.
    """
    # The meeting must exist before a task can be attached to it.
    meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
    if not meeting:
        raise Exception("Meeting not found")

    # Transcribe the most recently uploaded audio for this meeting.
    audio = (
        db.query(MeetingAudio)
        .filter(MeetingAudio.meeting_id == meeting_id)
        .order_by(MeetingAudio.upload_time.desc())
        .first()
    )
    if not audio:
        raise Exception("No audio file found for this meeting")

    # Persist the task in a "pending" state.
    new_task = TranscriptTask(
        task_id=str(uuid.uuid4()),
        meeting_id=meeting_id,
        model_id=model_id,
        language=language,
        status="pending",
        progress=0,
        created_at=datetime.utcnow()
    )
    db.add(new_task)
    meeting.status = "transcribing"
    db.commit()

    # NOTE: a real worker system would push the task id onto a Redis queue
    # here (e.g. "meeting:transcribe:queue"); for now the caller triggers
    # background processing itself using the returned task.
    return new_task
@staticmethod
async def process_transcript_task(db: Session, task_id: str):
    """Background worker body: run ASR for one transcript task end to end.

    Loads the task, resolves the ASR model and the meeting's hotwords,
    calls the local ASR service (reporting progress back into the task row
    and Redis), persists the returned segments, and — for meetings of type
    'upload' — auto-triggers a summarize task. On any failure the task,
    audio and meeting are all marked "failed" so the frontend stops polling.
    """
    print(f"[DEBUG] Processing transcript task {task_id}")
    task = db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()
    if not task:
        print(f"[ERROR] Task {task_id} not found in DB")
        return
    try:
        task.status = "processing"
        task.progress = 10
        # Commits are offloaded to a thread so the event loop is not blocked.
        await asyncio.to_thread(db.commit)
        # 1. Resolve the model configuration.
        model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
        if not model_config:
            # No model configured on the task: fall back to any ASR model.
            model_config = db.query(AIModel).filter(AIModel.model_type == "asr").first()
            if not model_config:
                raise Exception("No ASR model configuration found")
        # 2. Fetch the latest audio file for this meeting.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == task.meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()
        if not audio:
            raise Exception("Audio file missing")
        task.progress = 20
        await asyncio.to_thread(db.commit)
        # 3. Call the ASR service (local model only).
        # Prefer the model's configured base_url, otherwise fall back to settings.
        asr_base_url = model_config.base_url if model_config.base_url else settings.asr_api_base_url
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if not meeting:
            raise Exception("Meeting not found")
        # Collect hotwords visible to this meeting: shared scopes, plus the
        # owner's personal hotwords when the meeting has a user.
        hotword_filters = [Hotword.scope.in_(["public", "global"])]
        if meeting.user_id:
            hotword_filters.append(
                (Hotword.scope == "personal") & (Hotword.user_id == meeting.user_id)
            )
        hotwords = db.query(Hotword).filter(
            (hotword_filters[0]) if len(hotword_filters) == 1 else (hotword_filters[0] | hotword_filters[1])
        ).all()
        hotword_entries = []
        hotword_string_parts = []
        for hw in hotwords:
            hotword_entries.append({"word": hw.word, "weight": hw.weight})
            # "word:weight" string form only when a non-default weight is set.
            if hw.weight and hw.weight != 1:
                hotword_string_parts.append(f"{hw.word}:{hw.weight}")
            else:
                hotword_string_parts.append(hw.word)
        hotword_string = " ".join([p for p in hotword_string_parts if p])
        logger.info(f"Task {task_id}: Starting transcription with ASR Model at {asr_base_url}")
        print(f"[DEBUG] Calling ASR at {asr_base_url} for file: {audio.file_path}")
        # Progress callback handed to the ASR poller.
        async def update_progress(p: int, msg: Optional[str] = None):
            # Map ASR progress (0-100) into the 20-80 band of task progress;
            # never move backwards past what was already reported.
            new_progress = 20 + int(p * 0.6)
            print(f"[DEBUG] update_progress called: p={p}, new_progress={new_progress}, current={task.progress}")
            if new_progress > task.progress:
                task.progress = new_progress
                # Shared session across the async loop: commit in a worker
                # thread to avoid freezing the event loop, roll back on error.
                try:
                    await asyncio.to_thread(db.commit)
                    print(f"[DEBUG] DB Updated: progress={new_progress}")
                except Exception as dbe:
                    logger.error(f"DB Commit failed: {dbe}")
                    db.rollback()
            # Mirror the human-readable status message into Redis (TTL 1h);
            # Redis failures are non-fatal for the transcription itself.
            if msg:
                try:
                    await redis_client.setex(f"task:status:{task_id}", 3600, msg)
                    print(f"[DEBUG] Redis Updated: key=task:status:{task_id}, msg={msg}")
                except Exception as e:
                    logger.warning(f"Failed to update task status in Redis: {e}")
        # Transcribe via the local model API.
        segments_data = await MeetingService._call_local_asr_api(
            audio_path=audio.file_path,
            base_url=asr_base_url,
            language=task.language,
            hotwords=hotword_entries if hotword_entries else None,
            hotword_string=hotword_string,
            progress_callback=update_progress
        )
        print(f"[DEBUG] Received {len(segments_data)} segments from ASR")
        task.progress = 80
        await asyncio.to_thread(db.commit)
        # 4. Persist the transcription result.
        last_saved_end_ms = 0
        def parse_number(v):
            # Best-effort numeric coercion; returns None when not parseable.
            if isinstance(v, (int, float)):
                return float(v)
            if isinstance(v, str):
                try:
                    return float(v.strip())
                except Exception:
                    return None
            return None
        for seg in segments_data:
            # Local API returns: { "text": "...", "timestamp": [[start, end]], "speaker": "..." }
            # Adapt that shape to our TranscriptSegment schema.
            text = (seg.get("text") or "").strip()
            if not text:
                continue
            start_ms = None
            end_ms = None
            ts = seg.get("timestamp")
            if ts:
                if isinstance(ts, list) and len(ts) > 0:
                    try:
                        if isinstance(ts[0], list):
                            # Nested form: list of [start, end] pairs — span
                            # the segment from the earliest start to the
                            # latest end across all valid pairs.
                            starts = []
                            ends = []
                            for pair in ts:
                                if not isinstance(pair, (list, tuple)) or len(pair) < 2:
                                    continue
                                s = parse_number(pair[0])
                                e = parse_number(pair[1])
                                if s is not None and e is not None:
                                    starts.append(s)
                                    ends.append(e)
                            if starts and ends:
                                raw_start = min(starts)
                                raw_end = max(ends)
                                if raw_end < raw_start:
                                    raise ValueError("timestamp end < start")
                                # Heuristic: an end value below 1000 is assumed
                                # to be seconds and scaled to milliseconds.
                                if raw_end < 1000:
                                    raw_start *= 1000.0
                                    raw_end *= 1000.0
                                start_ms = int(raw_start)
                                end_ms = int(raw_end)
                        elif len(ts) >= 2:
                            # Flat form: [start, end].
                            raw_start = parse_number(ts[0])
                            raw_end = parse_number(ts[1])
                            if raw_start is None or raw_end is None:
                                raise ValueError("timestamp not numeric")
                            if raw_end < raw_start:
                                raise ValueError("timestamp end < start")
                            if raw_end < 1000:
                                raw_start *= 1000.0
                                raw_end *= 1000.0
                            start_ms = int(raw_start)
                            end_ms = int(raw_end)
                    except Exception:
                        # Malformed timestamps fall through to the fallbacks below.
                        start_ms = None
                        end_ms = None
            if start_ms is None or end_ms is None:
                # Fallback: some ASR responses carry begin/end fields instead.
                bt = parse_number(seg.get("begin_time") or seg.get("start_time"))
                et = parse_number(seg.get("end_time"))
                if bt is not None and et is not None:
                    if bt < 1000 and et < 1000:
                        bt *= 1000.0
                        et *= 1000.0
                    start_ms = int(bt)
                    end_ms = int(et)
            if start_ms is None or end_ms is None or end_ms < start_ms:
                # Last resort: synthesize a 1-second window right after the
                # previously saved segment so ordering stays monotonic.
                start_ms = last_saved_end_ms + 1
                end_ms = start_ms + 1000
            last_saved_end_ms = max(last_saved_end_ms, end_ms)
            transcript_segment = TranscriptSegment(
                meeting_id=task.meeting_id,
                audio_id=audio.audio_id,
                speaker_id=0,
                speaker_tag=seg.get("speaker", "Unknown"),
                start_time_ms=start_ms,
                end_time_ms=end_ms,
                text_content=text
            )
            db.add(transcript_segment)
        # 5. Mark the task complete.
        task.status = "completed"
        task.progress = 100
        task.completed_at = datetime.utcnow()
        # Update audio status.
        audio.processing_status = "completed"
        # Update the meeting; summarization may auto-start below, or be
        # triggered later by the user.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "transcribed"  # distinct status before summarizing
            # Auto-trigger summarization for uploaded meetings.
            if meeting.type == 'upload':
                try:
                    # Pick the meeting's chosen LLM if active, else the
                    # default active LLM model.
                    llm_model = None
                    if meeting.summary_model_id:
                        llm_model = db.query(AIModel).filter(
                            AIModel.model_id == meeting.summary_model_id,
                            AIModel.status == 1
                        ).first()
                    if not llm_model:
                        llm_model = db.query(AIModel).filter(
                            AIModel.model_type == 'llm',
                            AIModel.is_default == 1,
                            AIModel.status == 1
                        ).first()
                    # Likewise for the prompt template: explicit choice first,
                    # then the first active one (system templates preferred).
                    prompt_tmpl = None
                    if meeting.summary_prompt_id:
                        prompt_tmpl = db.query(PromptTemplate).filter(
                            PromptTemplate.id == meeting.summary_prompt_id
                        ).first()
                    if not prompt_tmpl:
                        prompt_tmpl = db.query(PromptTemplate).filter(
                            PromptTemplate.status == 1
                        ).order_by(PromptTemplate.is_system.desc(), PromptTemplate.id.asc()).first()
                    if llm_model and prompt_tmpl:
                        logger.info(f"Auto-triggering summary for meeting {meeting.meeting_id}")
                        # create_summarize_task commits on this shared
                        # session, so flush our own changes first.
                        await asyncio.to_thread(db.commit)
                        new_sum_task = await MeetingService.create_summarize_task(
                            db,
                            meeting_id=meeting.meeting_id,
                            prompt_id=prompt_tmpl.id,
                            model_id=llm_model.model_id
                        )
                        # No FastAPI BackgroundTasks object is available here,
                        # so schedule the worker directly on the event loop.
                        # run_summarize_worker opens its own DB session,
                        # because this session is closed when we return.
                        asyncio.create_task(MeetingService.run_summarize_worker(new_sum_task.task_id))
                        # Update meeting status to summarizing.
                        meeting.status = "summarizing"
                    else:
                        logger.warning(f"Skipping auto-summary: No default LLM or Prompt found (LLM: {llm_model}, Prompt: {prompt_tmpl})")
                except Exception as sum_e:
                    # Best-effort: transcription still succeeds even when the
                    # auto-summary trigger fails.
                    logger.error(f"Failed to auto-trigger summary: {sum_e}")
        await asyncio.to_thread(db.commit)
        logger.info(f"Task {task_id} transcription completed")
    except Exception as e:
        logger.error(f"Task {task_id} failed: {str(e)}")
        # NOTE(review): if the failure came from the DB session itself (a
        # failed flush/commit), a db.rollback() before these writes may be
        # required for the commit below to succeed — confirm.
        task.status = "failed"
        task.error_message = str(e)
        # Update audio status.
        audio = db.query(MeetingAudio).filter(
            MeetingAudio.meeting_id == task.meeting_id
        ).order_by(MeetingAudio.upload_time.desc()).first()
        if audio:
            audio.processing_status = "failed"
            audio.error_message = str(e)
        # Mark the meeting failed so the frontend knows to stop polling.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "failed"
        await asyncio.to_thread(db.commit)
@staticmethod
async def _call_local_asr_api(
    audio_path: str,
    base_url: str = "http://localhost:3050",
    language: str = "auto",
    hotwords: Optional[List[dict]] = None,
    hotword_string: Optional[str] = None,
    progress_callback: Optional[Callable[[int, Optional[str]], Awaitable[None]]] = None
) -> list:
    """
    Call local ASR API for transcription.
    Flow: Create Task -> Poll Status -> Get Result

    Task creation degrades through a cascade of payload shapes to cope
    with differing ASR server versions: JSON with "file_path" (with and
    without hotwords), JSON with "audio_path", and finally a multipart
    file upload. If the service is unreachable, mock segments are
    returned so the rest of the pipeline can be exercised.

    Returns the raw list of segment dicts from the ASR result.
    Raises when task creation fails (other than connection refusal),
    when the ASR task reports "failed", or on poll timeout (~20 min).
    """
    create_url = f"{base_url.rstrip('/')}/api/tasks/recognition"
    async with httpx.AsyncClient(timeout=30.0) as client:
        # 1. Create Task
        try:
            # Normalize hotwords into {word: int_weight}; default weight 20.
            normalized_hotwords = {}
            if hotwords:
                for hw in hotwords:
                    if isinstance(hw, dict):
                        word = hw.get("word")
                        weight = hw.get("weight")
                        if word:
                            normalized_hotwords[word] = int(weight) if weight is not None else 20
                    elif isinstance(hw, str):
                        normalized_hotwords[hw] = 20
            payload = {
                "file_path": audio_path,
                "language": language,
                "use_spk_id": True
            }
            if normalized_hotwords:
                payload["hotwords"] = normalized_hotwords
            elif hotword_string:
                payload["hotword"] = hotword_string
            response = await client.post(create_url, json=payload)
            # Server error: retry once without any hotword fields.
            if response.status_code >= 500:
                base_payload = {
                    "file_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                response = await client.post(create_url, json=base_payload)
            # 422: the server may expect "audio_path" instead of "file_path".
            if response.status_code == 422:
                alt_payload = {
                    "audio_path": audio_path,
                    "language": language,
                    "use_spk_id": True
                }
                if normalized_hotwords:
                    alt_payload["hotwords"] = normalized_hotwords
                elif hotword_string:
                    alt_payload["hotword"] = hotword_string
                response = await client.post(create_url, json=alt_payload)
                # Same alt shape, again retried without hotwords on 5xx.
                if response.status_code >= 500:
                    alt_base_payload = {
                        "audio_path": audio_path,
                        "language": language,
                        "use_spk_id": True
                    }
                    response = await client.post(create_url, json=alt_base_payload)
            # Last resort: multipart file upload with string form fields.
            if response.status_code == 422 or response.status_code >= 500:
                with open(audio_path, "rb") as f:
                    files = {
                        "file": (Path(audio_path).name, f, "application/octet-stream")
                    }
                    data_fields = {
                        "language": language,
                        "use_spk_id": "true"
                    }
                    if normalized_hotwords:
                        data_fields["hotwords"] = json.dumps(normalized_hotwords, ensure_ascii=False)
                    elif hotword_string:
                        data_fields["hotword"] = hotword_string
                    response = await client.post(create_url, data=data_fields, files=files)
            # Log the final failing response body before raising.
            if response.status_code >= 400:
                try:
                    logger.error(f"ASR create response: {response.status_code} {response.text}")
                except Exception:
                    pass
            response.raise_for_status()
            data = response.json()
            # Handle nested data structure {code: 200, data: {task_id: ...}}
            if "data" in data and isinstance(data["data"], dict) and "task_id" in data["data"]:
                task_id = data["data"]["task_id"]
            else:
                task_id = data.get("task_id")
            if not task_id:
                raise Exception("Failed to get task_id from ASR service")
        except Exception as e:
            logger.error(f"Failed to create ASR task: {e}")
            # Fallback for testing/mock if service is not running.
            # String-matching on the exception text detects connection refusal.
            if "Connection refused" in str(e) or "ConnectError" in str(e):
                logger.warning("ASR Service not reachable, using mock data for testing")
                await asyncio.sleep(2)
                return [
                    {"timestamp": [[0, 2500]], "text": "This is a mock transcription (Local ASR unreachable).", "speaker": "System"},
                    {"timestamp": [[2500, 5000]], "text": "Please ensure the ASR service is running at localhost:3050.", "speaker": "System"}
                ]
            raise
        # 2. Poll Status
        status_url = f"{base_url.rstrip('/')}/api/tasks/{task_id}"
        # Separate client with a shorter timeout for polling.
        poll_client = httpx.AsyncClient(timeout=10.0)
        max_retries = 600  # 20 minutes max wait (assuming 2s sleep)
        try:
            for _ in range(max_retries):
                try:
                    res = await poll_client.get(status_url)
                    # NOTE(review): raise_for_status raises HTTPStatusError,
                    # which the RequestError handler below does NOT catch, so
                    # a transient 5xx from the status endpoint aborts polling
                    # instead of being retried — confirm whether intended.
                    res.raise_for_status()
                    status_data = res.json()
                    progress = 0
                    # Handle nested data structure {code, data: {...}}.
                    if "data" in status_data and isinstance(status_data["data"], dict):
                        inner_data = status_data["data"]
                        status = inner_data.get("status")
                        result = inner_data.get("result", {})
                        error_msg = inner_data.get("msg") or status_data.get("msg")
                        # Progress may appear under several keys.
                        if "progress" in inner_data:
                            raw_progress = inner_data["progress"]
                        elif "percent" in inner_data:
                            raw_progress = inner_data["percent"]
                        elif "percentage" in inner_data:
                            raw_progress = inner_data["percentage"]
                        else:
                            raw_progress = 0
                        # A "message" field, when present, overrides the msg.
                        if "message" in inner_data and inner_data["message"]:
                            error_msg = inner_data["message"]
                    else:
                        status = status_data.get("status")
                        result = status_data.get("result", {})
                        error_msg = status_data.get("msg")
                        # Progress may appear under several keys.
                        if "progress" in status_data:
                            raw_progress = status_data["progress"]
                        elif "percent" in status_data:
                            raw_progress = status_data["percent"]
                        elif "percentage" in status_data:
                            raw_progress = status_data["percentage"]
                        else:
                            raw_progress = 0
                        # A "message" field, when present, overrides the msg.
                        if "message" in status_data and status_data["message"]:
                            error_msg = status_data["message"]
                    # Normalize progress: float 0.0-1.0, int 0-100, or a
                    # string such as "45%" / "0.45".
                    progress = 0
                    try:
                        if isinstance(raw_progress, (int, float)):
                            if 0.0 <= raw_progress <= 1.0 and isinstance(raw_progress, float):
                                progress = int(raw_progress * 100)
                            else:
                                progress = int(raw_progress)
                        elif isinstance(raw_progress, str):
                            if raw_progress.endswith("%"):
                                progress = int(float(raw_progress.strip("%")))
                            else:
                                val = float(raw_progress)
                                if 0.0 <= val <= 1.0:
                                    progress = int(val * 100)
                                else:
                                    progress = int(val)
                    except (ValueError, TypeError):
                        progress = 0
                    # Log raw status for debugging.
                    print(f"[DEBUG] ASR Polling: status={status}, raw_progress={raw_progress}, progress={progress}, msg={error_msg}")
                    # Report progress upstream; callback errors are non-fatal.
                    if progress_callback:
                        try:
                            status_msg = error_msg if error_msg else f"正在转写中... {progress}%"
                            print(f"[DEBUG] Invoking callback with progress={progress}, msg={status_msg}")
                            await progress_callback(progress, status_msg)
                        except Exception as e:
                            logger.warning(f"Progress callback error: {e}")
                    if status == "completed" or status == "success":
                        # Return the segments list, tolerating both shapes.
                        if isinstance(result, dict) and "segments" in result:
                            return result["segments"]
                        elif isinstance(result, list):
                            return result
                        else:
                            # Fallback or empty.
                            return []
                    elif status == "failed":
                        raise Exception(f"ASR Task failed: {error_msg}")
                    # Still processing.
                    await asyncio.sleep(2)
                except httpx.RequestError as e:
                    # Transient network error: log and keep polling.
                    logger.warning(f"Error polling ASR task {task_id}: {e}")
                    await asyncio.sleep(2)
        finally:
            await poll_client.aclose()
        raise Exception("ASR Task timed out")
2026-02-25 08:48:31 +00:00
@staticmethod
async def create_summarize_task(
    db: Session,
    meeting_id: int,
    prompt_id: int,
    model_id: int,
    extra_prompt: str = ""
):
    """Create and enqueue an LLM summarization task for a meeting.

    Builds the user prompt from the meeting's transcript segments in
    chronological order (optionally appending extra user instructions),
    stores a pending ``SummarizeTask``, marks the meeting as
    ``"summarizing"``, and pushes the task id onto the Redis queue.

    Raises a generic ``Exception`` when the meeting does not exist.
    """
    meeting = db.query(Meeting).filter(Meeting.meeting_id == meeting_id).first()
    if not meeting:
        raise Exception("Meeting not found")

    # Render each transcript segment as "[MM:SS] speaker: text".
    segments = (
        db.query(TranscriptSegment)
        .filter(TranscriptSegment.meeting_id == meeting_id)
        .order_by(TranscriptSegment.start_time_ms)
        .all()
    )
    lines = []
    for seg in segments:
        minutes, seconds = divmod(int(seg.start_time_ms // 1000), 60)
        who = seg.speaker_tag or f"发言人{seg.speaker_id or '?'}"
        lines.append(f"[{minutes:02d}:{seconds:02d}] {who}: {seg.text_content}")
    meeting_text = "\n".join(lines)

    # Final user prompt: transcript material plus optional extra instructions.
    user_prompt_content = f"### 会议转译内容 ###\n{meeting_text}"
    if extra_prompt:
        user_prompt_content += f"\n\n### 用户的额外指令 ###\n{extra_prompt}"

    # Persist the pending task (fields mirror the actual DB schema).
    task_id = str(uuid.uuid4())
    new_task = SummarizeTask(
        task_id=task_id,
        meeting_id=meeting_id,
        prompt_id=prompt_id,
        model_id=model_id,
        user_prompt=user_prompt_content,
        status="pending",
        progress=0,
        created_at=datetime.utcnow()
    )
    db.add(new_task)
    meeting.status = "summarizing"
    await asyncio.to_thread(db.commit)

    # Hand the task off to the background worker queue.
    await redis_client.lpush("meeting:summarize:queue", task_id)
    return new_task
2026-03-02 10:26:22 +00:00
@staticmethod
async def run_summarize_worker(task_id: str):
    """Run one summarize task with a dedicated DB session.

    Safe to schedule via ``asyncio.create_task`` from any context, since
    it never borrows a caller's (possibly already-closed) session. Any
    failure is logged and swallowed; the session is always closed.
    """
    session = SessionLocal()
    try:
        await MeetingService.process_summarize_task(session, task_id)
    except Exception as e:
        logger.error(f"Summarize worker (task {task_id}) failed: {e}")
    finally:
        session.close()
2026-02-25 08:48:31 +00:00
@staticmethod
async def process_summarize_task(db: Session, task_id: str):
    """Background worker: execute one summarize task to completion.

    Loads the model configuration and prompt template, calls the LLM with
    the stored user prompt, and writes the summary back to both the task
    and the meeting. On failure the task is marked "failed" and the
    meeting is returned to "transcribed" so the user can retry the summary.

    Silently returns when the task id is unknown.
    """
    task = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
    if not task:
        return
    try:
        task.status = "processing"
        task.progress = 15
        # Commits run in a thread so the event loop is not blocked.
        await asyncio.to_thread(db.commit)
        # 1. Resolve the model configuration.
        model_config = db.query(AIModel).filter(AIModel.model_id == task.model_id).first()
        if not model_config:
            raise Exception("AI 模型配置不存在")
        # 2. Fetch the prompt template content at execution time, with a
        # generic fallback when the template no longer exists.
        prompt_tmpl = db.query(PromptTemplate).filter(PromptTemplate.id == task.prompt_id).first()
        system_prompt = prompt_tmpl.content if prompt_tmpl else "请根据提供的会议转译内容生成准确的总结。"
        task.progress = 30
        await asyncio.to_thread(db.commit)
        # 3. Build the chat message structure.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": task.user_prompt}
        ]
        # Model parameters with sensible defaults.
        config_params = model_config.config or {}
        temperature = config_params.get("temperature", 0.7)
        top_p = config_params.get("top_p", 0.9)
        logger.info(f"Task {task_id}: Launching LLM request to {model_config.model_name}")
        task.progress = 50
        await asyncio.to_thread(db.commit)
        # 4. Call the LLM service.
        summary_result = await LLMService.chat_completion(
            api_key=model_config.api_key,
            base_url=model_config.base_url or "https://api.openai.com/v1",
            model_name=model_config.model_name,
            messages=messages,
            api_path=model_config.api_path or "/chat/completions",
            temperature=float(temperature),
            top_p=float(top_p)
        )
        # 5. Task done: write results back to the task and the meeting.
        task.result = summary_result
        task.status = "completed"
        task.progress = 100
        task.completed_at = datetime.utcnow()
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.summary = summary_result
            meeting.status = "completed"
        await asyncio.to_thread(db.commit)
        logger.info(f"Task {task_id} completed successfully")
    except Exception as e:
        logger.error(f"Task {task_id} execution error: {str(e)}")
        # Fix: if the exception originated in the session itself (failed
        # flush/commit), the session is invalid and any further commit
        # would raise again, losing the failure status. Roll back first so
        # the status update below can be persisted; the rollback itself is
        # best-effort.
        try:
            await asyncio.to_thread(db.rollback)
        except Exception as rb_e:
            logger.warning(f"Rollback after failure also failed: {rb_e}")
        task.status = "failed"
        task.error_message = str(e)
        # Restore meeting status to transcribed (so user can retry summary)
        # instead of draft.
        meeting = db.query(Meeting).filter(Meeting.meeting_id == task.meeting_id).first()
        if meeting:
            meeting.status = "transcribed"
        await asyncio.to_thread(db.commit)
2026-02-25 08:48:31 +00:00
@staticmethod
def get_task_status(db: Session, task_id: str):
    """Look up a task by id.

    Summarize tasks take precedence; when none matches, fall back to the
    transcript tasks. Returns None when neither table has the id.
    """
    summarize = db.query(SummarizeTask).filter(SummarizeTask.task_id == task_id).first()
    if summarize is not None:
        return summarize
    return db.query(TranscriptTask).filter(TranscriptTask.task_id == task_id).first()