# imetting/backend/app/services/llm_service.py
import json
import os
from typing import Optional, Dict, Generator, Any, List
import httpx
import app.core.config as config_module
from app.core.database import get_db_connection
from app.services.system_config_service import SystemConfigService
class LLMServiceError(Exception):
    """Structured exception raised when an LLM call fails.

    Attributes:
        message: human-readable error description (also the Exception text).
        status_code: optional upstream HTTP status code, when the failure
            originated from an HTTP error response.
    """

    def __init__(self, message: str, *, status_code: Optional[int] = None):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
class LLMService:
    """LLM service — focused on calling large-model APIs and managing prompts.

    All requests target OpenAI-compatible ``/chat/completions`` endpoints.
    Configuration defaults come from :class:`SystemConfigService`; per-call
    overrides may be supplied as plain dicts with several accepted key aliases.
    """

    @staticmethod
    def _use_system_proxy() -> bool:
        # Honour OS proxy settings only when explicitly enabled via env var.
        return os.getenv("IMEETING_USE_SYSTEM_PROXY", "").lower() in {"1", "true", "yes", "on"}

    @staticmethod
    def _create_httpx_client() -> httpx.Client:
        # trust_env controls whether httpx picks up HTTP(S)_PROXY and friends.
        return httpx.Client(
            trust_env=LLMService._use_system_proxy()
        )

    @staticmethod
    def _coerce_int(value: Any, default: int, minimum: Optional[int] = None) -> int:
        """Coerce ``value`` to int; fall back to ``default`` and clamp at ``minimum``."""
        try:
            normalized = int(value)
        except (TypeError, ValueError):
            normalized = default
        if minimum is not None:
            normalized = max(minimum, normalized)
        return normalized

    @staticmethod
    def _coerce_float(value: Any, default: float) -> float:
        """Coerce ``value`` to float, falling back to ``default`` on failure."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _build_timeout(timeout_seconds: int) -> httpx.Timeout:
        """Build an httpx.Timeout: full read/write budget, connect capped at 10s."""
        normalized_timeout = max(1, int(timeout_seconds))
        connect_timeout = min(10.0, float(normalized_timeout))
        return httpx.Timeout(
            connect=connect_timeout,
            read=float(normalized_timeout),
            write=float(normalized_timeout),
            pool=connect_timeout,
        )

    @staticmethod
    def _normalize_api_key(api_key: Optional[Any]) -> Optional[str]:
        """Return the stripped API key, or None when missing/blank."""
        if api_key is None:
            return None
        normalized = str(api_key).strip()
        return normalized or None

    @staticmethod
    def _normalize_model_code(model_code: Optional[Any]) -> Optional[str]:
        """Return the stripped model code, or None when missing/blank."""
        if model_code is None:
            return None
        normalized = str(model_code).strip()
        return normalized or None

    @classmethod
    def build_call_params_from_config(cls, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Merge an optional per-call ``config`` with system defaults into call params.

        Several key aliases are accepted per field (e.g. ``llm_model_name`` /
        ``model`` / ``model_name``) so callers with different config shapes work.

        Returns:
            Dict with endpoint_url, api_key, model, timeout, temperature,
            top_p, max_tokens and system_prompt.
        """
        config = config or {}
        endpoint_url = str(config.get("endpoint_url") or SystemConfigService.get_llm_endpoint_url() or "").strip()
        api_key = cls._normalize_api_key(config.get("api_key"))
        if api_key is None:
            api_key = cls._normalize_api_key(SystemConfigService.get_llm_api_key(config_module.QWEN_API_KEY))
        default_model = SystemConfigService.get_llm_model_name()
        default_timeout = SystemConfigService.get_llm_timeout()
        default_temperature = SystemConfigService.get_llm_temperature()
        default_top_p = SystemConfigService.get_llm_top_p()
        default_max_tokens = SystemConfigService.get_llm_max_tokens()
        default_system_prompt = SystemConfigService.get_llm_system_prompt(None)
        return {
            "endpoint_url": endpoint_url,
            "api_key": api_key,
            "model": str(
                config.get("llm_model_name")
                or config.get("model")
                or config.get("model_name")
                or default_model
            ).strip(),
            "timeout": cls._coerce_int(
                config.get("llm_timeout")
                or config.get("timeout")
                or config.get("time_out")
                or default_timeout,
                default_timeout,
                minimum=1,
            ),
            # temperature/top_p use explicit None checks so a legitimate 0 value
            # is preserved instead of being skipped by `or`.
            "temperature": cls._coerce_float(
                config.get("llm_temperature") if config.get("llm_temperature") is not None else config.get("temperature"),
                default_temperature,
            ),
            "top_p": cls._coerce_float(
                config.get("llm_top_p") if config.get("llm_top_p") is not None else config.get("top_p"),
                default_top_p,
            ),
            "max_tokens": cls._coerce_int(
                config.get("llm_max_tokens") or config.get("max_tokens") or default_max_tokens,
                default_max_tokens,
                minimum=1,
            ),
            "system_prompt": config.get("llm_system_prompt") or config.get("system_prompt") or default_system_prompt,
        }

    def _get_llm_call_params(self) -> Dict[str, Any]:
        """
        Get OpenAI-compatible call parameters from system configuration.

        Returns:
            Dict: endpoint_url, api_key, model, timeout, temperature, top_p, max_tokens
        """
        return self.build_call_params_from_config()

    @staticmethod
    def _build_chat_url(endpoint_url: str) -> str:
        """Append ``/chat/completions`` unless the endpoint already targets it."""
        base_url = (endpoint_url or "").rstrip("/")
        if base_url.endswith("/chat/completions"):
            return base_url
        return f"{base_url}/chat/completions"

    @staticmethod
    def _build_headers(api_key: Optional[str]) -> Dict[str, str]:
        """Build JSON request headers with an optional Bearer token."""
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _normalize_messages(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, str]]:
        """Normalize input into a chat-message list.

        ``messages`` (when given, even if empty) takes precedence over
        ``prompt``; entries without a role or with empty content are dropped.
        """
        normalized_messages: List[Dict[str, str]] = []
        if messages is not None:
            for message in messages:
                if not isinstance(message, dict):
                    continue
                role = str(message.get("role") or "").strip()
                if not role:
                    continue
                content = self._normalize_content(message.get("content"))
                if not content:
                    continue
                normalized_messages.append({"role": role, "content": content})
            return normalized_messages
        if prompt is not None:
            prompt_content = self._normalize_content(prompt)
            if prompt_content:
                normalized_messages.append({"role": "user", "content": prompt_content})
        return normalized_messages

    @staticmethod
    def _merge_system_messages(
        messages: List[Dict[str, str]],
        base_system_prompt: Optional[str],
    ) -> List[Dict[str, str]]:
        """Fold the configured system prompt and all leading system messages
        into one system message at the head of the returned list."""
        merged_messages: List[Dict[str, str]] = []
        merged_system_parts: List[str] = []
        if isinstance(base_system_prompt, str) and base_system_prompt.strip():
            merged_system_parts.append(base_system_prompt.strip())
        index = 0
        while index < len(messages) and messages[index].get("role") == "system":
            content = str(messages[index].get("content") or "").strip()
            if content:
                merged_system_parts.append(content)
            index += 1
        if merged_system_parts:
            merged_messages.append({"role": "system", "content": "\n\n".join(merged_system_parts)})
        merged_messages.extend(messages[index:])
        return merged_messages

    def _build_payload(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
        stream: bool = False,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build the /chat/completions request body.

        Raises:
            ValueError: when neither a usable prompt nor messages are supplied.
        """
        params = params or self._get_llm_call_params()
        normalized_messages = self._normalize_messages(prompt=prompt, messages=messages)
        normalized_messages = self._merge_system_messages(normalized_messages, params.get("system_prompt"))
        if not normalized_messages:
            raise ValueError("缺少 prompt 或 messages")
        payload = {
            "model": params["model"],
            "messages": normalized_messages,
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "max_tokens": params["max_tokens"],
            "stream": stream,
        }
        return payload

    @staticmethod
    def _normalize_content(content: Any) -> str:
        """Flatten message content (str, or list of str / {"text": ...} parts) to a string."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for item in content:
                if isinstance(item, str):
                    texts.append(item)
                elif isinstance(item, dict):
                    text = item.get("text")
                    if text:
                        texts.append(text)
            return "".join(texts)
        return ""

    def _extract_response_text(self, data: Dict[str, Any]) -> str:
        """Extract assistant text from a response body or streaming chunk."""
        choices = data.get("choices") or []
        if not choices:
            return ""
        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        content = message.get("content")
        if content:
            return self._normalize_content(content)
        # Streaming chunks carry incremental text under "delta", not "message".
        delta = first_choice.get("delta") or {}
        delta_content = delta.get("content")
        if delta_content:
            return self._normalize_content(delta_content)
        return ""

    def _validate_call_params(self, params: Dict[str, Any]) -> Optional[str]:
        """Return an error message when a required param is missing, else None."""
        if not params.get("endpoint_url"):
            return "缺少 endpoint_url"
        if not params.get("model"):
            return "缺少 model"
        if not params.get("api_key"):
            return "缺少API Key"
        return None

    @staticmethod
    def _extract_error_message_from_response(response: httpx.Response) -> str:
        """Best-effort extraction of a human-readable error from an HTTP response.

        Tries the OpenAI-style ``error`` object first, then a top-level
        ``message``, then falls back to (truncated) raw body text.
        """
        try:
            payload = response.json()
        except ValueError:
            payload = None
        if isinstance(payload, dict):
            error = payload.get("error")
            if isinstance(error, dict):
                parts = [
                    str(error.get("message") or "").strip(),
                    str(error.get("type") or "").strip(),
                    str(error.get("code") or "").strip(),
                ]
                message = " / ".join(part for part in parts if part)
                if message:
                    return message
            if isinstance(error, str) and error.strip():
                return error.strip()
            message = payload.get("message")
            if isinstance(message, str) and message.strip():
                return message.strip()
        text = (response.text or "").strip()
        return text[:500] if text else f"HTTP {response.status_code}"

    def get_call_params_by_model_code(self, model_code: Optional[str] = None) -> Dict[str, Any]:
        """Resolve call params for a specific model code, or system defaults.

        Raises:
            LLMServiceError: when the model code has no runtime configuration.
        """
        normalized_model_code = self._normalize_model_code(model_code)
        if not normalized_model_code:
            return self._get_llm_call_params()
        runtime_config = SystemConfigService.get_model_runtime_config(normalized_model_code)
        if not runtime_config:
            raise LLMServiceError(f"指定模型不可用: {normalized_model_code}")
        return self.build_call_params_from_config(runtime_config)

    def _resolve_call_params(
        self,
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Resolve call params; an explicit ``config`` wins over ``model_code``."""
        if config is not None:
            return self.build_call_params_from_config(config)
        return self.get_call_params_by_model_code(model_code)

    def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
        """
        Unified prompt-template lookup.

        Args:
            task_type: task type, e.g. 'MEETING_TASK', 'KNOWLEDGE_TASK'
            cursor: optional DB cursor; a new connection is opened when omitted
            prompt_id: optional prompt id; when omitted, the default active
                template for the task type is used

        Returns:
            str: prompt template content

        Raises:
            LLMServiceError: when no active matching template exists.
        """
        if prompt_id:
            # A specific prompt was requested — still scoped to the task type.
            query = """
                SELECT content
                FROM prompts
                WHERE id = %s AND task_type = %s AND is_active = TRUE
                LIMIT 1
            """
            params = (prompt_id, task_type)
        else:
            # Fall back to the default active template for this task type.
            query = """
                SELECT content
                FROM prompts
                WHERE task_type = %s
                  AND is_default = TRUE
                  AND is_active = TRUE
                LIMIT 1
            """
            params = (task_type,)
        if cursor:
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                # Caller-supplied cursors may yield dict or tuple rows.
                return result['content'] if isinstance(result, dict) else result[0]
        else:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                cursor.execute(query, params)
                result = cursor.fetchone()
                if result:
                    return result['content']
        prompt_label = f"ID={prompt_id}" if prompt_id else f"task_type={task_type} 的默认模版"
        raise LLMServiceError(f"未找到可用提示词模版:{prompt_label}")

    def stream_llm_api(
        self,
        prompt: Optional[str] = None,
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> Generator[str, None, None]:
        """Stream an OpenAI-compatible chat completion, yielding text chunks.

        Failures are yielded as ``"error: ..."`` strings rather than raised,
        so a streaming consumer always receives a terminal message.
        """
        try:
            params = self._resolve_call_params(model_code=model_code, config=config)
            validation_error = self._validate_call_params(params)
            if validation_error:
                yield f"error: {validation_error}"
                return
            timeout = self._build_timeout(params["timeout"])
            with self._create_httpx_client() as client:
                with client.stream(
                    "POST",
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt=prompt, messages=messages, stream=True, params=params),
                    timeout=timeout,
                ) as response:
                    response.raise_for_status()
                    for line in response.iter_lines():
                        # Server-Sent Events: only "data: {...}" lines carry chunks.
                        if not line or not line.startswith("data:"):
                            continue
                        data_line = line[5:].strip()
                        if not data_line or data_line == "[DONE]":
                            continue
                        try:
                            data = json.loads(data_line)
                        except json.JSONDecodeError:
                            # Tolerate malformed keep-alive / partial lines.
                            continue
                        new_content = self._extract_response_text(data)
                        if new_content:
                            yield new_content
        except LLMServiceError as e:
            error_msg = e.message or str(e)
            print(f"流式调用大模型API错误: {error_msg}")
            yield f"error: {error_msg}"
        except httpx.HTTPStatusError as e:
            detail = self._extract_error_message_from_response(e.response)
            error_msg = f"流式调用大模型API错误: HTTP {e.response.status_code} - {detail}"
            print(error_msg)
            yield f"error: {error_msg}"
        except httpx.TimeoutException:
            error_msg = f"流式调用大模型API超时: timeout={params['timeout']}s"
            print(error_msg)
            yield f"error: {error_msg}"
        except httpx.RequestError as e:
            error_msg = f"流式调用大模型API网络错误: {e}"
            print(error_msg)
            yield f"error: {error_msg}"
        except Exception as e:
            error_msg = f"流式调用大模型API错误: {e}"
            print(error_msg)
            yield f"error: {error_msg}"

    def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
        """Legacy entry point kept for backward compatibility."""
        return self.stream_llm_api(prompt)

    def call_llm_api(
        self,
        prompt: Optional[str] = None,
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """Non-streaming call; returns None (after logging) instead of raising."""
        try:
            return self.call_llm_api_or_raise(
                prompt=prompt,
                model_code=model_code,
                config=config,
                messages=messages,
            )
        except LLMServiceError as e:
            print(f"调用大模型API错误: {e}")
            return None

    def call_llm_api_or_raise(
        self,
        prompt: Optional[str] = None,
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Non-streaming call; raises :class:`LLMServiceError` on any failure."""
        params = self._resolve_call_params(model_code=model_code, config=config)
        return self.call_llm_api_with_config_or_raise(params, prompt=prompt, messages=messages)

    def call_llm_api_messages(
        self,
        messages: List[Dict[str, Any]],
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Optional[str]:
        """Multi-message non-streaming call; returns None on failure."""
        return self.call_llm_api(prompt=None, model_code=model_code, config=config, messages=messages)

    def call_llm_api_messages_or_raise(
        self,
        messages: List[Dict[str, Any]],
        model_code: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Multi-message non-streaming call; raises :class:`LLMServiceError` on failure."""
        return self.call_llm_api_or_raise(prompt=None, model_code=model_code, config=config, messages=messages)

    def call_llm_api_with_config(
        self,
        params: Dict[str, Any],
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """Non-streaming call with explicit params; returns None on failure."""
        try:
            return self.call_llm_api_with_config_or_raise(params, prompt=prompt, messages=messages)
        except LLMServiceError as e:
            print(f"调用大模型API错误: {e}")
            return None

    def call_llm_api_with_config_or_raise(
        self,
        params: Dict[str, Any],
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Non-streaming call with explicit params; raises a structured error.

        Raises:
            LLMServiceError: on invalid params, HTTP/network failures,
                timeouts, or an empty model response.
        """
        validation_error = self._validate_call_params(params)
        if validation_error:
            raise LLMServiceError(validation_error)
        timeout = self._build_timeout(params["timeout"])
        try:
            with self._create_httpx_client() as client:
                response = client.post(
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt=prompt, messages=messages, params=params),
                    timeout=timeout,
                )
                response.raise_for_status()
                content = self._extract_response_text(response.json())
                if content:
                    return content
                raise LLMServiceError("API调用失败: 返回内容为空")
        except LLMServiceError:
            # Bug fix: propagate our own structured error untouched instead of
            # letting the generic handler below re-wrap it (losing status_code).
            raise
        except httpx.HTTPStatusError as e:
            detail = self._extract_error_message_from_response(e.response)
            raise LLMServiceError(
                f"HTTP {e.response.status_code} - {detail}",
                status_code=e.response.status_code,
            ) from e
        except httpx.TimeoutException:
            raise LLMServiceError(f"调用超时: timeout={params['timeout']}s")
        except httpx.RequestError as e:
            raise LLMServiceError(f"网络错误: {e}") from e
        except Exception as e:
            raise LLMServiceError(str(e)) from e

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Legacy entry point kept for backward compatibility."""
        return self.call_llm_api(prompt)

    def test_model(self, config: Dict[str, Any], prompt: Optional[str] = None) -> Dict[str, Any]:
        """Run a one-shot connectivity test against the configured model.

        Returns a summary with the model, endpoint, a truncated response
        preview, and the parameters that were used.
        """
        params = self.build_call_params_from_config(config)
        test_prompt = prompt or "请用一句中文回复LLM测试成功。"
        content = self.call_llm_api_with_config_or_raise(params, test_prompt)
        if not content:
            # Defensive: the *_or_raise call already rejects empty content.
            # Raise the service's own error type for consistency with callers.
            raise LLMServiceError("模型无有效返回内容")
        return {
            "model": params["model"],
            "endpoint_url": params["endpoint_url"],
            "response_preview": content[:500],
            "used_params": {
                "timeout": params["timeout"],
                "temperature": params["temperature"],
                "top_p": params["top_p"],
                "max_tokens": params["max_tokens"],
            },
        }
# Manual smoke test (requires a configured database with prompt templates).
if __name__ == '__main__':
    print("--- 运行LLM服务测试 ---")
    llm_service = LLMService()
    # Fetch the default prompt template for each known task type.
    meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
    print(f"会议任务提示词: {meeting_prompt[:100]}...")
    knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
    print("--- LLM服务测试完成 ---")