# imetting/backend/app/services/llm_service.py
import json
import os
from typing import Optional, Dict, Generator, Any
import requests
import app.core.config as config_module
from app.core.database import get_db_connection
from app.services.system_config_service import SystemConfigService
class LLMService:
    """LLM service - focused on calling OpenAI-compatible large-model APIs
    and managing task prompts."""

    @staticmethod
    def _create_requests_session() -> "requests.Session":
        """Build a requests session for API calls.

        Environment proxy settings (HTTP_PROXY etc.) are honoured only when
        the IMEETING_USE_SYSTEM_PROXY env var is set to a truthy value.
        """
        session = requests.Session()
        # trust_env controls whether requests reads proxy vars from the env.
        session.trust_env = os.getenv("IMEETING_USE_SYSTEM_PROXY", "").lower() in {"1", "true", "yes", "on"}
        return session

    @staticmethod
    def build_call_params_from_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Merge an optional per-call config dict with system-wide defaults.

        Args:
            config: Optional overrides. Both ``llm_*``-prefixed keys and bare
                aliases (e.g. ``llm_model_name`` / ``model``) are accepted;
                the prefixed key wins.

        Returns:
            Dict with endpoint_url, api_key, model, timeout, temperature,
            top_p, max_tokens and system_prompt.
        """
        config = config or {}

        def pick(*keys: str, default: Any = None) -> Any:
            # First non-None value among the given keys, else the default.
            # (Unlike the previous chained-get form, an explicit None value
            # falls through to the default instead of reaching float()/int().)
            for key in keys:
                value = config.get(key)
                if value is not None:
                    return value
            return default

        endpoint_url = config.get("endpoint_url") or SystemConfigService.get_llm_endpoint_url()
        api_key = config.get("api_key")
        if api_key is None:
            # Only a *missing* key falls back; an empty string is kept as-is.
            api_key = SystemConfigService.get_llm_api_key(config_module.QWEN_API_KEY)
        return {
            "endpoint_url": endpoint_url,
            "api_key": api_key,
            "model": config.get("llm_model_name") or config.get("model") or SystemConfigService.get_llm_model_name(),
            "timeout": int(config.get("llm_timeout") or config.get("timeout") or SystemConfigService.get_llm_timeout()),
            "temperature": float(pick("llm_temperature", "temperature", default=SystemConfigService.get_llm_temperature())),
            "top_p": float(pick("llm_top_p", "top_p", default=SystemConfigService.get_llm_top_p())),
            "max_tokens": int(config.get("llm_max_tokens") or config.get("max_tokens") or SystemConfigService.get_llm_max_tokens()),
            "system_prompt": config.get("llm_system_prompt") or config.get("system_prompt") or SystemConfigService.get_llm_system_prompt(None),
        }

    def _get_llm_call_params(self) -> Dict[str, Any]:
        """Return OpenAI-compatible call params built from system config only.

        Returns:
            Dict with endpoint_url, api_key, model, timeout, temperature,
            top_p, max_tokens and system_prompt.
        """
        return self.build_call_params_from_config()

    @staticmethod
    def _build_chat_url(endpoint_url: str) -> str:
        """Normalise an endpoint URL to end in /chat/completions (idempotent)."""
        base_url = (endpoint_url or "").rstrip("/")
        if base_url.endswith("/chat/completions"):
            return base_url
        return f"{base_url}/chat/completions"

    @staticmethod
    def _build_headers(api_key: Optional[str]) -> Dict[str, str]:
        """HTTP headers for the API call; Authorization only when a key is set."""
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _build_payload(self, prompt: str, stream: bool = False, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build the chat-completions request body.

        Args:
            prompt: User message content.
            stream: Whether to request a streaming response.
            params: Pre-built call params; fetched from system config when
                omitted.
        """
        params = params or self._get_llm_call_params()
        messages = []
        system_prompt = params.get("system_prompt")
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return {
            "model": params["model"],
            "messages": messages,
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "max_tokens": params["max_tokens"],
            "stream": stream,
        }

    @staticmethod
    def _normalize_content(content: Any) -> str:
        """Flatten message content to a plain string.

        Accepts a str, or a list mixing strings and {"text": ...} dicts
        (multi-part content); anything else yields "".
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for item in content:
                if isinstance(item, str):
                    texts.append(item)
                elif isinstance(item, dict):
                    text = item.get("text")
                    if text:
                        texts.append(text)
            return "".join(texts)
        return ""

    def _extract_response_text(self, data: Dict[str, Any]) -> str:
        """Pull the text out of a (possibly streaming) chat-completions body."""
        choices = data.get("choices") or []
        if not choices:
            return ""
        first_choice = choices[0] or {}
        # Non-streaming responses carry text under message.content ...
        message = first_choice.get("message") or {}
        content = message.get("content")
        if content:
            return self._normalize_content(content)
        # ... while streaming chunks use delta.content instead.
        delta = first_choice.get("delta") or {}
        delta_content = delta.get("content")
        if delta_content:
            return self._normalize_content(delta_content)
        return ""

    def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
        """Fetch the prompt text for a task.

        Args:
            task_type: Task type, e.g. 'MEETING_TASK' or 'KNOWLEDGE_TASK'.
            cursor: Optional existing DB cursor; a new connection is opened
                (and closed) when omitted.
            prompt_id: Optional prompt id; when given, that specific active
                prompt is looked up, otherwise the task's default prompt.

        Returns:
            The prompt content, or a built-in default when nothing matches.
        """
        if prompt_id:
            # A specific prompt was requested - look it up directly.
            query = """
                SELECT content
                FROM prompts
                WHERE id = %s AND task_type = %s AND is_active = TRUE
                LIMIT 1
            """
            params = (prompt_id, task_type)
        else:
            # Otherwise use the task's default prompt.
            query = """
                SELECT content
                FROM prompts
                WHERE task_type = %s
                  AND is_default = TRUE
                  AND is_active = TRUE
                LIMIT 1
            """
            params = (task_type,)

        if cursor:
            cursor.execute(query, params)
            row = cursor.fetchone()
        else:
            with get_db_connection() as connection:
                db_cursor = connection.cursor(dictionary=True)
                try:
                    db_cursor.execute(query, params)
                    row = db_cursor.fetchone()
                finally:
                    # Close the cursor explicitly rather than relying on the
                    # connection context manager to clean it up.
                    db_cursor.close()

        if row:
            # Rows may be dicts (dictionary=True cursors) or plain tuples.
            return row['content'] if isinstance(row, dict) else row[0]
        return self._get_default_prompt(task_type)

    def _get_default_prompt(self, task_name: str) -> str:
        """Built-in fallback prompts used when the DB has no matching row."""
        system_prompt = SystemConfigService.get_llm_system_prompt("请根据提供的内容进行总结和分析。")
        default_prompts = {
            'MEETING_TASK': system_prompt,
            'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
        }
        return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")

    def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
        """Stream content from the OpenAI-compatible API as it arrives.

        Yields incremental content strings; on failure yields a single
        "error: ..." string (kept for compatibility with existing consumers).
        """
        params = self._get_llm_call_params()
        if not params["api_key"]:
            yield "error: 缺少API Key"
            return
        try:
            session = self._create_requests_session()
            try:
                response = session.post(
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt, stream=True),
                    timeout=params["timeout"],
                    stream=True,
                )
                response.raise_for_status()
                # Server-sent events: each payload line is prefixed "data:".
                for line in response.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data:"):
                        continue
                    data_line = line[5:].strip()
                    if not data_line or data_line == "[DONE]":
                        continue
                    try:
                        data = json.loads(data_line)
                    except json.JSONDecodeError:
                        # Skip malformed or partial keep-alive lines.
                        continue
                    new_content = self._extract_response_text(data)
                    if new_content:
                        yield new_content
            finally:
                session.close()
        except Exception as e:
            error_msg = f"流式调用大模型API错误: {e}"
            print(error_msg)
            yield f"error: {error_msg}"

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Non-streaming call using system-config parameters."""
        params = self._get_llm_call_params()
        return self.call_llm_api_with_config(params, prompt)

    def call_llm_api_with_config(self, params: Dict[str, Any], prompt: str) -> Optional[str]:
        """Non-streaming call with explicit parameters.

        Args:
            params: Call params as produced by build_call_params_from_config.
            prompt: User message content.

        Returns:
            The response text, or None on any failure (errors are printed).
        """
        if not params["api_key"]:
            print("调用大模型API错误: 缺少API Key")
            return None
        try:
            session = self._create_requests_session()
            try:
                response = session.post(
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt, params=params),
                    timeout=params["timeout"],
                )
                response.raise_for_status()
                content = self._extract_response_text(response.json())
            finally:
                session.close()
            if content:
                return content
            print("API调用失败: 返回内容为空")
            return None
        except Exception as e:
            print(f"调用大模型API错误: {e}")
            return None

    def test_model(self, config: Dict[str, Any], prompt: Optional[str] = None) -> Dict[str, Any]:
        """Smoke-test a model configuration with a short prompt.

        Raises:
            Exception: when the model returns no usable content.
        """
        params = self.build_call_params_from_config(config)
        test_prompt = prompt or "请用一句中文回复LLM测试成功。"
        content = self.call_llm_api_with_config(params, test_prompt)
        if not content:
            raise Exception("模型无有效返回内容")
        return {
            "model": params["model"],
            "endpoint_url": params["endpoint_url"],
            "response_preview": content[:500],
            "used_params": {
                "timeout": params["timeout"],
                "temperature": params["temperature"],
                "top_p": params["top_p"],
                "max_tokens": params["max_tokens"],
            },
        }
# Manual smoke test: fetches task prompts (requires a reachable database,
# otherwise the built-in fallback prompts are returned).
if __name__ == '__main__':
    print("--- 运行LLM服务测试 ---")
    service = LLMService()
    meeting_prompt = service.get_task_prompt('MEETING_TASK')
    print(f"会议任务提示词: {meeting_prompt[:100]}...")
    knowledge_prompt = service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"知识库任务提示词: {knowledge_prompt[:100]}...")
    print("--- LLM服务测试完成 ---")