127 lines
4.8 KiB
Python
127 lines
4.8 KiB
Python
import asyncio
|
|
import os
|
|
import tempfile
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
from fastapi import HTTPException, UploadFile
|
|
from sqlmodel import Session
|
|
|
|
from core.speech_service import SpeechDisabledError, SpeechDurationError, SpeechServiceError
|
|
from models.bot import BotInstance
|
|
|
|
|
|
class SpeechTranscriptionService:
    """Transcribe uploaded audio for a bot using an injected speech backend.

    Each upload is spooled to a temporary file under ``data_root``, handed to
    ``speech_service.transcribe_file`` on a worker thread, and the temporary
    file is removed afterwards. Backend errors are translated into
    ``HTTPException`` responses.
    """

    # Bytes read from the upload per iteration while spooling it to disk.
    _UPLOAD_CHUNK_BYTES = 1024 * 1024
    # Suffix used when the client-supplied name has no usable extension.
    _DEFAULT_EXT = ".webm"
    # Longest client-supplied extension (including the dot) that is kept.
    _MAX_EXT_LEN = 12

    def __init__(
        self,
        *,
        data_root: str,
        speech_service: Any,
        get_speech_runtime_settings: Callable[[], Dict[str, Any]],
        logger: Any,
    ) -> None:
        """Store collaborators.

        Args:
            data_root: Directory in which temporary upload files are created.
            speech_service: Backend whose ``transcribe_file(path, language)`` is
                called on a worker thread; its result is read with
                ``.get("text")``, ``.get("duration_seconds")`` and
                ``.get("language")``.
            get_speech_runtime_settings: Called once per request; the returned
                dict is read for ``enabled``, ``default_language``,
                ``max_audio_seconds``, ``model`` and ``device``.
            logger: Logger-like object; ``warning`` and ``exception`` are used.
        """
        self._data_root = data_root
        self._speech_service = speech_service
        self._get_speech_runtime_settings = get_speech_runtime_settings
        self._logger = logger

    def _require_bot(self, *, session: Session, bot_id: str) -> BotInstance:
        """Return the ``BotInstance`` with primary key ``bot_id``.

        Raises:
            HTTPException: 404 if no such bot exists.
        """
        bot = session.get(BotInstance, bot_id)
        if not bot:
            raise HTTPException(status_code=404, detail="Bot not found")
        return bot

    @classmethod
    def _sanitize_upload_name(cls, filename: Optional[str]) -> tuple[str, str]:
        """Derive ``(safe_name, ext)`` from an untrusted client file name.

        ``safe_name`` is the base name with path separators neutralised; it is
        used only for logging. ``ext`` is a lower-cased suffix such as ``".wav"``
        that is safe to use as a temp-file suffix. It falls back to ``.webm``
        when the name has no extension, or the extension is too long or not
        plain ASCII alphanumerics.
        """
        original_name = str(filename or "audio.webm").strip() or "audio.webm"
        safe_name = os.path.basename(original_name).replace("\\", "_").replace("/", "_")
        ext = os.path.splitext(safe_name)[1].strip().lower()
        ext_body = ext[1:]
        # Reject empty, over-long or exotic suffixes (spaces, NUL, non-ASCII, a
        # bare ".") so they cannot break or confuse temp-file creation.
        if len(ext) > cls._MAX_EXT_LEN or not (ext_body.isascii() and ext_body.isalnum()):
            ext = cls._DEFAULT_EXT
        return safe_name, ext

    def _discard_temp_file(self, path: str) -> None:
        """Best-effort delete of a spooled upload; failures are logged, not raised."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as exc:
            # A leftover file only wastes disk space, so the request must not
            # fail here, but the leak should be visible in the logs.
            self._logger.warning(
                "speech temp file cleanup failed path=%s detail=%s",
                path,
                exc,
            )

    async def transcribe(
        self,
        *,
        session: Session,
        bot_id: str,
        file: UploadFile,
        language: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Transcribe an uploaded audio file for ``bot_id``.

        Args:
            session: Database session used to check that the bot exists.
            bot_id: Identifier of the bot the upload belongs to.
            file: Uploaded audio. Its name only influences the temp-file suffix.
            language: Optional language hint. ``None`` or blank falls back to the
                configured ``default_language``.

        Returns:
            A dict with ``bot_id``, ``text``, ``duration_seconds``,
            ``max_audio_seconds``, ``model``, ``device`` and ``language``.

        Raises:
            HTTPException:
                - 404 if the bot is unknown.
                - 400 if speech is disabled, no file or an empty file was
                  uploaded, no speech was detected, or the backend raised
                  ``SpeechDisabledError`` / ``SpeechServiceError``.
                - 413 if the backend raised ``SpeechDurationError``.
                - 500 for any other failure.
        """
        self._require_bot(session=session, bot_id=bot_id)
        speech_settings = self._get_speech_runtime_settings()
        if not speech_settings["enabled"]:
            raise HTTPException(status_code=400, detail="Speech recognition is disabled")
        if not file:
            raise HTTPException(status_code=400, detail="no audio file uploaded")

        safe_name, ext = self._sanitize_upload_name(file.filename)

        tmp_path = ""
        bytes_written = 0
        try:
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=ext, prefix=".speech_", dir=self._data_root
            ) as tmp:
                # Record the path immediately so ``finally`` can delete the file
                # even if spooling fails part-way.
                tmp_path = tmp.name
                while chunk := await file.read(self._UPLOAD_CHUNK_BYTES):
                    # Disk writes are blocking; keep them off the event loop.
                    await asyncio.to_thread(tmp.write, chunk)
                    bytes_written += len(chunk)

            if bytes_written <= 0:
                raise HTTPException(status_code=400, detail="audio payload is empty")

            resolved_language = str(language or "").strip() or speech_settings["default_language"]
            # The backend call is synchronous, so run it on a worker thread.
            result = await asyncio.to_thread(self._speech_service.transcribe_file, tmp_path, resolved_language)
            text = str(result.get("text") or "").strip()
            if not text:
                raise HTTPException(status_code=400, detail="No speech detected")
            return {
                "bot_id": bot_id,
                "text": text,
                "duration_seconds": result.get("duration_seconds"),
                "max_audio_seconds": speech_settings["max_audio_seconds"],
                "model": speech_settings["model"],
                "device": speech_settings["device"],
                "language": result.get("language") or resolved_language,
            }
        except SpeechDisabledError as exc:
            self._logger.warning(
                "speech transcribe disabled bot_id=%s file=%s language=%s detail=%s",
                bot_id,
                safe_name,
                language,
                exc,
            )
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except SpeechDurationError as exc:
            self._logger.warning(
                "speech transcribe too long bot_id=%s file=%s language=%s max_seconds=%s",
                bot_id,
                safe_name,
                language,
                speech_settings["max_audio_seconds"],
            )
            raise HTTPException(
                status_code=413,
                detail=f"Audio duration exceeds {speech_settings['max_audio_seconds']} seconds",
            ) from exc
        except SpeechServiceError as exc:
            self._logger.exception(
                "speech transcribe failed bot_id=%s file=%s language=%s",
                bot_id,
                safe_name,
                language,
            )
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except HTTPException:
            # Our own 4xx responses from inside the try pass through unchanged.
            raise
        except Exception as exc:
            self._logger.exception(
                "speech transcribe unexpected error bot_id=%s file=%s language=%s",
                bot_id,
                safe_name,
                language,
            )
            # NOTE(review): the detail echoes internal exception text to the
            # client; consider a generic message if that is undesirable.
            raise HTTPException(status_code=500, detail=f"speech transcription failed: {exc}") from exc
        finally:
            try:
                await file.close()
            except Exception:
                # Closing the upload is best-effort and must not mask the
                # response or the original error.
                pass
            if tmp_path:
                self._discard_temp_file(tmp_path)
|