import json
|
||
import os
|
||
from typing import Optional, Dict, Generator, Any, List
|
||
|
||
import httpx
|
||
|
||
from app.core.database import get_db_connection
|
||
from app.services.system_config_service import SystemConfigService
|
||
|
||
|
||
class LLMServiceError(Exception):
    """Structured exception raised when an LLM call fails.

    Attributes:
        message: Human-readable error description.
        status_code: HTTP status code of the failed call, when applicable.
    """

    def __init__(self, message: str, *, status_code: Optional[int] = None):
        super().__init__(message)
        # Keep both pieces accessible to callers that build API responses.
        self.message = message
        self.status_code = status_code
|
||
|
||
|
||
class LLMService:
    """LLM service — OpenAI-compatible API calls and prompt-template management."""

    @staticmethod
    def _use_system_proxy() -> bool:
        """Return True when IMEETING_USE_SYSTEM_PROXY opts in to system proxies."""
        flag = os.getenv("IMEETING_USE_SYSTEM_PROXY", "")
        return flag.lower() in {"1", "true", "yes", "on"}
|
||
|
||
@staticmethod
def _create_httpx_client() -> httpx.Client:
    """Build an httpx client; honor proxy env vars only when explicitly opted in."""
    trust_env = LLMService._use_system_proxy()
    return httpx.Client(trust_env=trust_env)
|
||
|
||
@staticmethod
|
||
def _coerce_int(value: Any, default: int, minimum: Optional[int] = None) -> int:
|
||
try:
|
||
normalized = int(value)
|
||
except (TypeError, ValueError):
|
||
normalized = default
|
||
if minimum is not None:
|
||
normalized = max(minimum, normalized)
|
||
return normalized
|
||
|
||
@staticmethod
|
||
def _coerce_float(value: Any, default: float) -> float:
|
||
try:
|
||
return float(value)
|
||
except (TypeError, ValueError):
|
||
return default
|
||
|
||
@staticmethod
def _build_timeout(timeout_seconds: int) -> httpx.Timeout:
    """Build an httpx.Timeout: full duration for read/write, connect/pool capped at 10s."""
    total = float(max(1, int(timeout_seconds)))
    connect = min(10.0, total)
    return httpx.Timeout(connect=connect, read=total, write=total, pool=connect)
|
||
|
||
@staticmethod
|
||
def _normalize_api_key(api_key: Optional[Any]) -> Optional[str]:
|
||
if api_key is None:
|
||
return None
|
||
normalized = str(api_key).strip()
|
||
return normalized or None
|
||
|
||
@staticmethod
|
||
def _normalize_model_code(model_code: Optional[Any]) -> Optional[str]:
|
||
if model_code is None:
|
||
return None
|
||
normalized = str(model_code).strip()
|
||
return normalized or None
|
||
|
||
@classmethod
def build_call_params_from_config(cls, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Merge an optional per-model config with system-wide LLM defaults.

    Args:
        config: Raw configuration dict. Keys may use either the prefixed
            form (e.g. ``llm_model_name``) or the bare form (``model``);
            may be None/empty, in which case system defaults apply.

    Returns:
        Dict with normalized ``endpoint_url``, ``api_key``, ``model``,
        ``timeout``, ``temperature``, ``top_p``, ``max_tokens`` and
        ``system_prompt``.
    """
    config = config or {}
    # Endpoint: explicit config wins, then the system-wide setting.
    endpoint_url = str(config.get("endpoint_url") or SystemConfigService.get_llm_endpoint_url() or "").strip()
    api_key = cls._normalize_api_key(config.get("api_key"))
    if api_key is None:
        # Fall back to the globally configured key only when the config
        # carries no usable key of its own.
        api_key = cls._normalize_api_key(SystemConfigService.get_llm_api_key())

    default_model = SystemConfigService.get_llm_model_name()
    default_timeout = SystemConfigService.get_llm_timeout()
    default_temperature = SystemConfigService.get_llm_temperature()
    default_top_p = SystemConfigService.get_llm_top_p()
    default_max_tokens = SystemConfigService.get_llm_max_tokens()
    default_system_prompt = SystemConfigService.get_llm_system_prompt(None)

    return {
        "endpoint_url": endpoint_url,
        "api_key": api_key,
        # Several legacy key spellings are accepted for each field below.
        "model": str(
            config.get("llm_model_name")
            or config.get("model")
            or config.get("model_name")
            or default_model
        ).strip(),
        "timeout": cls._coerce_int(
            config.get("llm_timeout")
            or config.get("timeout")
            or config.get("time_out")
            or default_timeout,
            default_timeout,
            minimum=1,
        ),
        # temperature/top_p use explicit None checks (not ``or``) so a
        # deliberate 0 / 0.0 value is respected rather than replaced.
        "temperature": cls._coerce_float(
            config.get("llm_temperature") if config.get("llm_temperature") is not None else config.get("temperature"),
            default_temperature,
        ),
        "top_p": cls._coerce_float(
            config.get("llm_top_p") if config.get("llm_top_p") is not None else config.get("top_p"),
            default_top_p,
        ),
        "max_tokens": cls._coerce_int(
            config.get("llm_max_tokens") or config.get("max_tokens") or default_max_tokens,
            default_max_tokens,
            minimum=1,
        ),
        "system_prompt": config.get("llm_system_prompt") or config.get("system_prompt") or default_system_prompt,
    }
|
||
|
||
def _get_llm_call_params(self) -> Dict[str, Any]:
    """Resolve OpenAI-compatible call parameters from the global system config.

    Returns:
        Dict with endpoint_url, api_key, model, timeout, temperature,
        top_p, max_tokens and system_prompt keys.
    """
    return self.build_call_params_from_config()
|
||
|
||
@staticmethod
|
||
def _build_chat_url(endpoint_url: str) -> str:
|
||
base_url = (endpoint_url or "").rstrip("/")
|
||
if base_url.endswith("/chat/completions"):
|
||
return base_url
|
||
return f"{base_url}/chat/completions"
|
||
|
||
@staticmethod
|
||
def _build_headers(api_key: Optional[str]) -> Dict[str, str]:
|
||
headers = {"Content-Type": "application/json"}
|
||
if api_key:
|
||
headers["Authorization"] = f"Bearer {api_key}"
|
||
return headers
|
||
|
||
def _normalize_messages(
|
||
self,
|
||
prompt: Optional[str] = None,
|
||
messages: Optional[List[Dict[str, Any]]] = None,
|
||
) -> List[Dict[str, str]]:
|
||
normalized_messages: List[Dict[str, str]] = []
|
||
|
||
if messages is not None:
|
||
for message in messages:
|
||
if not isinstance(message, dict):
|
||
continue
|
||
role = str(message.get("role") or "").strip()
|
||
if not role:
|
||
continue
|
||
content = self._normalize_content(message.get("content"))
|
||
if not content:
|
||
continue
|
||
normalized_messages.append({"role": role, "content": content})
|
||
return normalized_messages
|
||
|
||
if prompt is not None:
|
||
prompt_content = self._normalize_content(prompt)
|
||
if prompt_content:
|
||
normalized_messages.append({"role": "user", "content": prompt_content})
|
||
|
||
return normalized_messages
|
||
|
||
@staticmethod
|
||
def _merge_system_messages(
|
||
messages: List[Dict[str, str]],
|
||
base_system_prompt: Optional[str],
|
||
) -> List[Dict[str, str]]:
|
||
merged_messages: List[Dict[str, str]] = []
|
||
merged_system_parts: List[str] = []
|
||
|
||
if isinstance(base_system_prompt, str) and base_system_prompt.strip():
|
||
merged_system_parts.append(base_system_prompt.strip())
|
||
|
||
index = 0
|
||
while index < len(messages) and messages[index].get("role") == "system":
|
||
content = str(messages[index].get("content") or "").strip()
|
||
if content:
|
||
merged_system_parts.append(content)
|
||
index += 1
|
||
|
||
if merged_system_parts:
|
||
merged_messages.append({"role": "system", "content": "\n\n".join(merged_system_parts)})
|
||
|
||
merged_messages.extend(messages[index:])
|
||
return merged_messages
|
||
|
||
def _build_payload(
|
||
self,
|
||
prompt: Optional[str] = None,
|
||
messages: Optional[List[Dict[str, Any]]] = None,
|
||
stream: bool = False,
|
||
params: Optional[Dict[str, Any]] = None,
|
||
) -> Dict[str, Any]:
|
||
params = params or self._get_llm_call_params()
|
||
normalized_messages = self._normalize_messages(prompt=prompt, messages=messages)
|
||
normalized_messages = self._merge_system_messages(normalized_messages, params.get("system_prompt"))
|
||
|
||
if not normalized_messages:
|
||
raise ValueError("缺少 prompt 或 messages")
|
||
|
||
payload = {
|
||
"model": params["model"],
|
||
"messages": normalized_messages,
|
||
"temperature": params["temperature"],
|
||
"top_p": params["top_p"],
|
||
"max_tokens": params["max_tokens"],
|
||
"stream": stream,
|
||
}
|
||
return payload
|
||
|
||
@staticmethod
|
||
def _normalize_content(content: Any) -> str:
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
texts = []
|
||
for item in content:
|
||
if isinstance(item, str):
|
||
texts.append(item)
|
||
elif isinstance(item, dict):
|
||
text = item.get("text")
|
||
if text:
|
||
texts.append(text)
|
||
return "".join(texts)
|
||
return ""
|
||
|
||
def _extract_response_text(self, data: Dict[str, Any]) -> str:
|
||
choices = data.get("choices") or []
|
||
if not choices:
|
||
return ""
|
||
|
||
first_choice = choices[0] or {}
|
||
message = first_choice.get("message") or {}
|
||
content = message.get("content")
|
||
if content:
|
||
return self._normalize_content(content)
|
||
|
||
delta = first_choice.get("delta") or {}
|
||
delta_content = delta.get("content")
|
||
if delta_content:
|
||
return self._normalize_content(delta_content)
|
||
|
||
return ""
|
||
|
||
def _validate_call_params(self, params: Dict[str, Any]) -> Optional[str]:
|
||
if not params.get("endpoint_url"):
|
||
return "缺少 endpoint_url"
|
||
if not params.get("model"):
|
||
return "缺少 model"
|
||
if not params.get("api_key"):
|
||
return "缺少API Key"
|
||
return None
|
||
|
||
@staticmethod
def _extract_error_message_from_response(response: httpx.Response) -> str:
    """Derive a human-readable error message from a failed HTTP response.

    Prefers the OpenAI-style ``{"error": {...}}`` payload (message/type/
    code joined with " / "), then a string "error", then a top-level
    "message", then the raw body truncated to 500 chars, and finally the
    bare HTTP status.
    """
    try:
        payload = response.json()
    except ValueError:
        payload = None

    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict):
            fields = (error.get("message"), error.get("type"), error.get("code"))
            pieces = [str(field or "").strip() for field in fields]
            joined = " / ".join(piece for piece in pieces if piece)
            if joined:
                return joined
        if isinstance(error, str) and error.strip():
            return error.strip()
        top_message = payload.get("message")
        if isinstance(top_message, str) and top_message.strip():
            return top_message.strip()

    body = (response.text or "").strip()
    return body[:500] if body else f"HTTP {response.status_code}"
|
||
|
||
def get_call_params_by_model_code(self, model_code: Optional[str] = None) -> Dict[str, Any]:
    """Resolve call params for a specific model code, or system defaults.

    Args:
        model_code: Optional model identifier; blank/None falls back to
            the system-wide LLM configuration.

    Raises:
        LLMServiceError: when a model code is given but no runtime config
            exists for it.
    """
    code = self._normalize_model_code(model_code)
    if not code:
        return self._get_llm_call_params()

    runtime_config = SystemConfigService.get_model_runtime_config(code)
    if not runtime_config:
        raise LLMServiceError(f"指定模型不可用: {code}")
    return self.build_call_params_from_config(runtime_config)
|
||
|
||
def _resolve_call_params(
    self,
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Pick call params: an explicit config wins over a model-code lookup."""
    if config is None:
        return self.get_call_params_by_model_code(model_code)
    return self.build_call_params_from_config(config)
|
||
|
||
def get_task_prompt(self, task_type: str, cursor=None, prompt_id: Optional[int] = None) -> str:
    """Fetch a prompt template for a task type.

    Args:
        task_type: Task type, e.g. 'MEETING_TASK', 'KNOWLEDGE_TASK'.
        cursor: Optional DB cursor to reuse; when omitted a fresh
            connection/cursor is opened and used instead.
        prompt_id: Optional prompt id. When given (and non-zero), that
            exact active prompt is fetched; otherwise the task type's
            default active prompt is used.

    Returns:
        str: The prompt template content.

    Raises:
        LLMServiceError: when no matching active prompt exists.
    """
    # With an explicit prompt_id, fetch that specific active prompt.
    if prompt_id:
        query = """
            SELECT content
            FROM prompts
            WHERE id = %s AND task_type = %s AND is_active = TRUE
            LIMIT 1
        """
        params = (prompt_id, task_type)
    else:
        # Otherwise fall back to the task type's default prompt.
        query = """
            SELECT content
            FROM prompts
            WHERE task_type = %s
              AND is_default = TRUE
              AND is_active = TRUE
            LIMIT 1
        """
        params = (task_type,)

    if cursor:
        cursor.execute(query, params)
        result = cursor.fetchone()
        if result:
            # A caller-supplied cursor may be tuple- or dict-based.
            return result['content'] if isinstance(result, dict) else result[0]
    else:
        with get_db_connection() as connection:
            # dictionary=True so rows come back as dicts keyed by column.
            cursor = connection.cursor(dictionary=True)
            cursor.execute(query, params)
            result = cursor.fetchone()
            if result:
                return result['content']

    # Nothing found on either path: raise with a label describing the lookup.
    prompt_label = f"ID={prompt_id}" if prompt_id else f"task_type={task_type} 的默认模版"
    raise LLMServiceError(f"未找到可用提示词模版:{prompt_label}")
|
||
|
||
def stream_llm_api(
    self,
    prompt: Optional[str] = None,
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
) -> Generator[str, None, None]:
    """Stream a chat completion from an OpenAI-compatible API.

    Yields incremental content chunks. On any failure a single
    "error: ..." string is yielded instead of raising, so consumers can
    forward errors over the same streaming channel.

    Args:
        prompt: Single user prompt (used when *messages* is not given).
        model_code: Optional model code used to look up runtime config.
        config: Explicit config dict; takes precedence over *model_code*.
        messages: Full chat message list; takes precedence over *prompt*.
    """
    try:
        params = self._resolve_call_params(model_code=model_code, config=config)
        validation_error = self._validate_call_params(params)
        if validation_error:
            yield f"error: {validation_error}"
            return

        timeout = self._build_timeout(params["timeout"])
        with self._create_httpx_client() as client:
            with client.stream(
                "POST",
                self._build_chat_url(params["endpoint_url"]),
                headers=self._build_headers(params["api_key"]),
                json=self._build_payload(prompt=prompt, messages=messages, stream=True, params=params),
                timeout=timeout,
            ) as response:
                response.raise_for_status()

                # Server-sent events: only "data:" lines carry payloads.
                for line in response.iter_lines():
                    if not line or not line.startswith("data:"):
                        continue

                    data_line = line[5:].strip()  # drop the "data:" prefix
                    if not data_line or data_line == "[DONE]":
                        continue

                    try:
                        data = json.loads(data_line)
                    except json.JSONDecodeError:
                        # Skip malformed / partial SSE lines silently.
                        continue

                    new_content = self._extract_response_text(data)
                    if new_content:
                        yield new_content
    except LLMServiceError as e:
        error_msg = e.message or str(e)
        print(f"流式调用大模型API错误: {error_msg}")
        yield f"error: {error_msg}"
    except httpx.HTTPStatusError as e:
        # Surface the server-provided error detail alongside the status.
        detail = self._extract_error_message_from_response(e.response)
        error_msg = f"流式调用大模型API错误: HTTP {e.response.status_code} - {detail}"
        print(error_msg)
        yield f"error: {error_msg}"
    except httpx.TimeoutException:
        # params is always bound here: httpx timeouts can only occur after
        # param resolution succeeded and the request was issued.
        error_msg = f"流式调用大模型API超时: timeout={params['timeout']}s"
        print(error_msg)
        yield f"error: {error_msg}"
    except httpx.RequestError as e:
        error_msg = f"流式调用大模型API网络错误: {e}"
        print(error_msg)
        yield f"error: {error_msg}"
    except Exception as e:
        error_msg = f"流式调用大模型API错误: {e}"
        print(error_msg)
        yield f"error: {error_msg}"
|
||
|
||
def call_llm_api(
    self,
    prompt: Optional[str] = None,
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
) -> Optional[str]:
    """Non-streaming call to an OpenAI-compatible LLM API.

    Returns the response text, or None when the call fails (the error is
    printed rather than propagated).
    """
    try:
        return self.call_llm_api_or_raise(
            prompt=prompt,
            model_code=model_code,
            config=config,
            messages=messages,
        )
    except LLMServiceError as exc:
        print(f"调用大模型API错误: {exc}")
        return None
|
||
|
||
def call_llm_api_or_raise(
    self,
    prompt: Optional[str] = None,
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
) -> str:
    """Non-streaming LLM call; raises LLMServiceError on any failure."""
    call_params = self._resolve_call_params(model_code=model_code, config=config)
    return self.call_llm_api_with_config_or_raise(call_params, prompt=prompt, messages=messages)
|
||
|
||
def call_llm_api_messages(
    self,
    messages: List[Dict[str, Any]],
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Multi-message non-streaming call; returns None instead of raising."""
    return self.call_llm_api(prompt=None, model_code=model_code, config=config, messages=messages)
|
||
|
||
def call_llm_api_messages_or_raise(
    self,
    messages: List[Dict[str, Any]],
    model_code: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
) -> str:
    """Multi-message non-streaming call; raises LLMServiceError on failure."""
    return self.call_llm_api_or_raise(prompt=None, model_code=model_code, config=config, messages=messages)
|
||
|
||
def call_llm_api_with_config(
    self,
    params: Dict[str, Any],
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
) -> Optional[str]:
    """Non-streaming call with explicit params; returns None instead of raising."""
    try:
        return self.call_llm_api_with_config_or_raise(params, prompt=prompt, messages=messages)
    except LLMServiceError as exc:
        print(f"调用大模型API错误: {exc}")
        return None
|
||
|
||
def call_llm_api_with_config_or_raise(
    self,
    params: Dict[str, Any],
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, Any]]] = None,
) -> str:
    """Non-streaming chat completion using explicit call params.

    Args:
        params: Resolved call params (see build_call_params_from_config).
        prompt: Single user prompt (used when *messages* is not given).
        messages: Full chat message list; takes precedence over *prompt*.

    Returns:
        str: Non-empty response text.

    Raises:
        LLMServiceError: on invalid params, HTTP/network/timeout failures,
            or an empty response body.
    """
    validation_error = self._validate_call_params(params)
    if validation_error:
        raise LLMServiceError(validation_error)

    timeout = self._build_timeout(params["timeout"])
    try:
        with self._create_httpx_client() as client:
            response = client.post(
                self._build_chat_url(params["endpoint_url"]),
                headers=self._build_headers(params["api_key"]),
                json=self._build_payload(prompt=prompt, messages=messages, params=params),
                timeout=timeout,
            )
            response.raise_for_status()
            content = self._extract_response_text(response.json())
            if content:
                return content
            raise LLMServiceError("API调用失败: 返回内容为空")
    except httpx.HTTPStatusError as e:
        # Surface the server-provided error detail and status code.
        detail = self._extract_error_message_from_response(e.response)
        raise LLMServiceError(
            f"HTTP {e.response.status_code} - {detail}",
            status_code=e.response.status_code,
        ) from e
    except httpx.TimeoutException:
        raise LLMServiceError(f"调用超时: timeout={params['timeout']}s")
    except httpx.RequestError as e:
        raise LLMServiceError(f"网络错误: {e}") from e
    except Exception as e:
        # NOTE(review): this broad catch also re-wraps the empty-content
        # LLMServiceError raised in the try block above (same message and
        # type are preserved, with __cause__ attached).
        raise LLMServiceError(str(e)) from e
|
||
|
||
def test_model(self, config: Dict[str, Any], prompt: Optional[str] = None) -> Dict[str, Any]:
    """Smoke-test an LLM configuration with a tiny completion request.

    Args:
        config: Raw model configuration (endpoint, key, model, tuning knobs).
        prompt: Optional test prompt; a default Chinese probe is used otherwise.

    Returns:
        Dict with the resolved model/endpoint, a response preview
        (at most 500 chars) and the effective tuning parameters.

    Raises:
        LLMServiceError: when the call fails or yields no content.
    """
    params = self.build_call_params_from_config(config)
    test_prompt = prompt or "请用一句中文回复:LLM测试成功。"
    content = self.call_llm_api_with_config_or_raise(params, test_prompt)
    if not content:
        # Defensive: *_or_raise already rejects empty content. Raise the
        # module's structured error type instead of a bare Exception for
        # consistency; LLMServiceError subclasses Exception, so callers
        # catching Exception are unaffected.
        raise LLMServiceError("模型无有效返回内容")

    return {
        "model": params["model"],
        "endpoint_url": params["endpoint_url"],
        "response_preview": content[:500],
        "used_params": {
            "timeout": params["timeout"],
            "temperature": params["temperature"],
            "top_p": params["top_p"],
            "max_tokens": params["max_tokens"],
        },
    }
|
||
|
||
|
||
# Manual smoke test (requires a configured database with prompt templates).
if __name__ == '__main__':
    print("--- 运行LLM服务测试 ---")
    llm_service = LLMService()

    # Fetch the default prompt templates for two task types.
    meeting_prompt = llm_service.get_task_prompt('MEETING_TASK')
    print(f"会议任务提示词: {meeting_prompt[:100]}...")

    knowledge_prompt = llm_service.get_task_prompt('KNOWLEDGE_TASK')
    print(f"知识库任务提示词: {knowledge_prompt[:100]}...")

    print("--- LLM服务测试完成 ---")
|