# LLM service module (app/services/llm_service.py)
import json
|
||
import os
|
||
from typing import Optional, Dict, Generator, Any
|
||
|
||
import requests
|
||
|
||
import app.core.config as config_module
|
||
from app.core.database import get_db_connection
|
||
from app.services.system_config_service import SystemConfigService
|
||
|
||
|
||
class LLMService:
    """LLM service — focused on large-model API calls and prompt management.

    Provides OpenAI-compatible chat-completion calls (streaming and
    non-streaming) and prompt lookup from the ``prompts`` database table.
    """

    @staticmethod
    def _create_requests_session() -> requests.Session:
        """Return a new HTTP session.

        System proxy settings (environment variables / OS proxy config) are
        honored only when the IMEETING_USE_SYSTEM_PROXY environment variable
        is set to a truthy value ("1", "true", "yes", "on").
        """
        session = requests.Session()
        session.trust_env = os.getenv("IMEETING_USE_SYSTEM_PROXY", "").lower() in {"1", "true", "yes", "on"}
        return session

    @staticmethod
    def build_call_params_from_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build OpenAI-compatible call parameters from *config* plus defaults.

        Both ``llm_``-prefixed keys (``llm_model_name``, ``llm_timeout``, ...)
        and bare keys (``model``, ``timeout``, ...) are accepted; prefixed keys
        take precedence, and SystemConfigService supplies the final fallback.

        Args:
            config: Optional per-call overrides; may be None.

        Returns:
            Dict with endpoint_url, api_key, model, timeout, temperature,
            top_p, max_tokens and system_prompt.
        """
        config = config or {}

        def _first_not_none(*values: Any) -> Any:
            # Pick the first explicitly-set value. 0 / 0.0 are valid settings
            # for temperature/top_p, so an `or` chain would wrongly skip them.
            for value in values:
                if value is not None:
                    return value
            return None

        endpoint_url = config.get("endpoint_url") or SystemConfigService.get_llm_endpoint_url()
        api_key = config.get("api_key")
        if api_key is None:
            api_key = SystemConfigService.get_llm_api_key(config_module.QWEN_API_KEY)

        # FIX: the previous code used `config.get("temperature", default)` as
        # the fallback, which returns a stored explicit None and then crashed
        # in float(None). _first_not_none skips explicit Nones at every level.
        temperature = _first_not_none(
            config.get("llm_temperature"),
            config.get("temperature"),
            SystemConfigService.get_llm_temperature(),
        )
        top_p = _first_not_none(
            config.get("llm_top_p"),
            config.get("top_p"),
            SystemConfigService.get_llm_top_p(),
        )

        return {
            "endpoint_url": endpoint_url,
            "api_key": api_key,
            "model": config.get("llm_model_name") or config.get("model") or SystemConfigService.get_llm_model_name(),
            "timeout": int(config.get("llm_timeout") or config.get("timeout") or SystemConfigService.get_llm_timeout()),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "max_tokens": int(config.get("llm_max_tokens") or config.get("max_tokens") or SystemConfigService.get_llm_max_tokens()),
            "system_prompt": config.get("llm_system_prompt") or config.get("system_prompt") or SystemConfigService.get_llm_system_prompt(None),
        }

    def _get_llm_call_params(self) -> Dict[str, Any]:
        """Get OpenAI-compatible call parameters from system configuration.

        Returns:
            Dict containing endpoint_url, api_key, model, timeout,
            temperature, top_p, max_tokens and system_prompt.
        """
        return self.build_call_params_from_config()

    @staticmethod
    def _build_chat_url(endpoint_url: str) -> str:
        """Normalize *endpoint_url* into a full ``/chat/completions`` URL.

        Accepts either a bare base URL or one that already ends with the
        chat-completions path; trailing slashes are stripped either way.
        """
        base_url = (endpoint_url or "").rstrip("/")
        if base_url.endswith("/chat/completions"):
            return base_url
        return f"{base_url}/chat/completions"

    @staticmethod
    def _build_headers(api_key: Optional[str]) -> Dict[str, str]:
        """Build request headers; a Bearer token is added only when a key is given."""
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _build_payload(self, prompt: str, stream: bool = False, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build the chat-completions JSON payload for *prompt*.

        Args:
            prompt: The user message content.
            stream: Whether to request a streaming response.
            params: Pre-built call parameters; fetched from config when None.

        Returns:
            Dict ready to be sent as the request's JSON body.
        """
        params = params or self._get_llm_call_params()
        messages = []
        system_prompt = params.get("system_prompt")
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return {
            "model": params["model"],
            "messages": messages,
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "max_tokens": params["max_tokens"],
            "stream": stream,
        }

    @staticmethod
    def _normalize_content(content: Any) -> str:
        """Flatten a message ``content`` field into plain text.

        Handles both a plain string and the list-of-parts form, where each
        part is either a string or a dict carrying a ``text`` key. Anything
        else yields the empty string.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for item in content:
                if isinstance(item, str):
                    texts.append(item)
                elif isinstance(item, dict):
                    text = item.get("text")
                    if text:
                        texts.append(text)
            return "".join(texts)
        return ""

    def _extract_response_text(self, data: Dict[str, Any]) -> str:
        """Extract text from a chat-completions response or stream chunk.

        Looks at ``choices[0].message.content`` first (non-streaming
        responses) and falls back to ``choices[0].delta.content``
        (streaming chunks). Returns "" when neither is present.
        """
        choices = data.get("choices") or []
        if not choices:
            return ""

        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        content = message.get("content")
        if content:
            return self._normalize_content(content)

        delta = first_choice.get("delta") or {}
        delta_content = delta.get("content")
        if delta_content:
            return self._normalize_content(delta_content)

        return ""

    def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
        """Unified prompt lookup.

        Args:
            task_type: Task type, e.g. 'MEETING_TASK', 'KNOWLEDGE_TASK'.
            cursor: Optional database cursor; if given it is used as-is,
                otherwise a new connection is opened.
            prompt_id: Optional prompt id; when given that specific prompt is
                fetched, otherwise the default prompt for the task type.

        Returns:
            str: The prompt content, or a built-in default when none is found.
        """
        if prompt_id:
            # A specific prompt was requested — fetch it directly.
            query = """
                SELECT content
                FROM prompts
                WHERE id = %s AND task_type = %s AND is_active = TRUE
                LIMIT 1
            """
            params = (prompt_id, task_type)
        else:
            # Otherwise fetch the default prompt for this task type.
            query = """
                SELECT content
                FROM prompts
                WHERE task_type = %s
                AND is_default = TRUE
                AND is_active = TRUE
                LIMIT 1
            """
            params = (task_type,)

        if cursor:
            # Caller-supplied cursor: caller owns its lifecycle. It may be a
            # tuple cursor or a dict cursor, so handle both row shapes.
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                return result['content'] if isinstance(result, dict) else result[0]
        else:
            with get_db_connection() as connection:
                cursor = connection.cursor(dictionary=True)
                try:
                    cursor.execute(query, params)
                    result = cursor.fetchone()
                finally:
                    # FIX: the cursor was previously never closed (leak).
                    cursor.close()
                if result:
                    return result['content']

        # Nothing found in the database — fall back to the built-in default.
        return self._get_default_prompt(task_type)

    def _get_default_prompt(self, task_name: str) -> str:
        """Return the built-in default prompt for *task_name*."""
        system_prompt = SystemConfigService.get_llm_system_prompt("请根据提供的内容进行总结和分析。")
        default_prompts = {
            'MEETING_TASK': system_prompt,
            'KNOWLEDGE_TASK': "请根据提供的信息生成知识库文章。",
        }
        return default_prompts.get(task_name, "请根据提供的内容进行总结和分析。")

    def _call_llm_api_stream(self, prompt: str) -> Generator[str, None, None]:
        """Call the OpenAI-compatible LLM API in streaming mode.

        Yields text fragments as they arrive; on failure yields a single
        "error: ..." string instead of raising.
        """
        params = self._get_llm_call_params()
        if not params["api_key"]:
            yield "error: 缺少API Key"
            return

        try:
            # Session is a context manager; closing it releases the pooled
            # connections even if iteration is abandoned mid-stream.
            with self._create_requests_session() as session:
                response = session.post(
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt, stream=True),
                    timeout=params["timeout"],
                    stream=True,
                )
                response.raise_for_status()

                # Server-sent events: only "data:"-prefixed lines carry JSON;
                # "[DONE]" marks the end of the stream.
                for line in response.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data:"):
                        continue

                    data_line = line[5:].strip()
                    if not data_line or data_line == "[DONE]":
                        continue

                    try:
                        data = json.loads(data_line)
                    except json.JSONDecodeError:
                        # Skip malformed keep-alive/partial chunks silently.
                        continue

                    new_content = self._extract_response_text(data)
                    if new_content:
                        yield new_content
        except Exception as e:
            error_msg = f"流式调用大模型API错误: {e}"
            print(error_msg)
            yield f"error: {error_msg}"

    def _call_llm_api(self, prompt: str) -> Optional[str]:
        """Call the OpenAI-compatible LLM API (non-streaming) with system config."""
        params = self._get_llm_call_params()
        return self.call_llm_api_with_config(params, prompt)

    def call_llm_api_with_config(self, params: Dict[str, Any], prompt: str) -> Optional[str]:
        """Call the OpenAI-compatible LLM API (non-streaming) with explicit config.

        Args:
            params: Call parameters as produced by build_call_params_from_config.
            prompt: The user prompt to send.

        Returns:
            The response text, or None on any failure (missing key, HTTP
            error, empty content) — errors are printed, never raised.
        """
        if not params["api_key"]:
            print("调用大模型API错误: 缺少API Key")
            return None

        try:
            with self._create_requests_session() as session:
                response = session.post(
                    self._build_chat_url(params["endpoint_url"]),
                    headers=self._build_headers(params["api_key"]),
                    json=self._build_payload(prompt, params=params),
                    timeout=params["timeout"],
                )
                response.raise_for_status()
                content = self._extract_response_text(response.json())

            if content:
                return content
            print("API调用失败: 返回内容为空")
            return None
        except Exception as e:
            print(f"调用大模型API错误: {e}")
            return None

    def test_model(self, config: Dict[str, Any], prompt: Optional[str] = None) -> Dict[str, Any]:
        """Smoke-test a model configuration with a one-shot prompt.

        Args:
            config: Model configuration overrides (see build_call_params_from_config).
            prompt: Optional custom test prompt.

        Returns:
            Dict with the model, endpoint, a response preview (first 500
            chars) and the effective call parameters.

        Raises:
            Exception: If the model returned no usable content.
        """
        params = self.build_call_params_from_config(config)
        test_prompt = prompt or "请用一句中文回复:LLM测试成功。"
        content = self.call_llm_api_with_config(params, test_prompt)
        if not content:
            raise Exception("模型无有效返回内容")

        return {
            "model": params["model"],
            "endpoint_url": params["endpoint_url"],
            "response_preview": content[:500],
            "used_params": {
                "timeout": params["timeout"],
                "temperature": params["temperature"],
                "top_p": params["top_p"],
                "max_tokens": params["max_tokens"],
            },
        }
|
||
|
||
|
||
# Manual smoke test for the prompt-lookup path.
if __name__ == '__main__':
    print("--- 运行LLM服务测试 ---")
    service = LLMService()

    # Fetch and preview the prompt for each built-in task type.
    meeting = service.get_task_prompt('MEETING_TASK')
    print(f"会议任务提示词: {meeting[:100]}...")

    knowledge = service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"知识库任务提示词: {knowledge[:100]}...")

    print("--- LLM服务测试完成 ---")
|