imetting_backend/test/test_prompt_id_feature.py

177 lines
6.1 KiB
Python
Raw Normal View History

2025-12-11 08:48:12 +00:00
"""
测试提示词模版选择功能
"""
import sys
sys.path.insert(0, 'app')
from app.services.llm_service import LLMService
from app.services.async_meeting_service import AsyncMeetingService
from app.core.database import get_db_connection
def test_get_active_prompts():
"""测试获取启用的提示词列表"""
print("\n=== 测试1: 获取启用的提示词列表 ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取MEETING_TASK类型的启用模版
query = """
SELECT id, name, is_default
FROM prompts
WHERE task_type = %s AND is_active = TRUE
ORDER BY is_default DESC, created_at DESC
"""
cursor.execute(query, ('MEETING_TASK',))
prompts = cursor.fetchall()
print(f"✓ 找到 {len(prompts)} 个启用的会议任务模版:")
for p in prompts:
default_flag = " [默认]" if p['is_default'] else ""
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
return prompts
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def test_get_task_prompt_with_id(prompts):
"""测试通过prompt_id获取提示词内容"""
print("\n=== 测试2: 通过prompt_id获取提示词内容 ===")
if not prompts:
print("⚠ 没有可用的提示词模版,跳过测试")
return
llm_service = LLMService()
# 测试获取第一个提示词
test_prompt = prompts[0]
try:
content = llm_service.get_task_prompt('MEETING_TASK', prompt_id=test_prompt['id'])
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
print(f" 内容长度: {len(content)} 字符")
print(f" 内容预览: {content[:100]}...")
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
# 测试获取默认提示词不指定prompt_id
try:
default_content = llm_service.get_task_prompt('MEETING_TASK')
print(f"✓ 成功获取默认提示词")
print(f" 内容长度: {len(default_content)} 字符")
except Exception as e:
print(f"✗ 获取默认提示词失败: {e}")
def test_async_meeting_service_signature():
"""测试async_meeting_service的方法签名"""
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
import inspect
async_service = AsyncMeetingService()
# 检查start_summary_generation方法签名
sig = inspect.signature(async_service.start_summary_generation)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ start_summary_generation 方法支持 prompt_id 参数")
print(f" 参数列表: {params}")
else:
print(f"✗ start_summary_generation 方法缺少 prompt_id 参数")
print(f" 参数列表: {params}")
# 检查monitor_and_auto_summarize方法签名
sig2 = inspect.signature(async_service.monitor_and_auto_summarize)
params2 = list(sig2.parameters.keys())
if 'prompt_id' in params2:
print(f"✓ monitor_and_auto_summarize 方法支持 prompt_id 参数")
print(f" 参数列表: {params2}")
else:
print(f"✗ monitor_and_auto_summarize 方法缺少 prompt_id 参数")
print(f" 参数列表: {params2}")
def test_database_schema():
"""测试数据库schema是否包含prompt_id列"""
print("\n=== 测试4: 验证数据库schema ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查llm_tasks表是否有prompt_id列
cursor.execute("""
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = 'llm_tasks'
AND COLUMN_NAME = 'prompt_id'
""")
result = cursor.fetchone()
if result:
print(f"✓ llm_tasks 表包含 prompt_id 列")
print(f" 类型: {result['DATA_TYPE']}")
print(f" 可空: {result['IS_NULLABLE']}")
print(f" 默认值: {result['COLUMN_DEFAULT']}")
else:
print(f"✗ llm_tasks 表缺少 prompt_id 列")
except Exception as e:
print(f"✗ 数据库检查失败: {e}")
import traceback
traceback.print_exc()
def test_api_endpoints():
"""测试API端点定义"""
print("\n=== 测试5: 验证API端点定义 ===")
try:
from app.api.endpoints.meetings import GenerateSummaryRequest
import inspect
# 检查GenerateSummaryRequest模型
fields = GenerateSummaryRequest.__fields__
if 'prompt_id' in fields:
print(f"✓ GenerateSummaryRequest 包含 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
else:
print(f"✗ GenerateSummaryRequest 缺少 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
# 检查audio_service.handle_audio_upload签名
from app.services.audio_service import handle_audio_upload
sig = inspect.signature(handle_audio_upload)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ handle_audio_upload 方法支持 prompt_id 参数")
else:
print(f"✗ handle_audio_upload 方法缺少 prompt_id 参数")
except Exception as e:
print(f"✗ API端点检查失败: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
print("=" * 60)
print("开始测试提示词模版选择功能")
print("=" * 60)
# 运行所有测试
prompts = test_get_active_prompts()
test_get_task_prompt_with_id(prompts)
test_async_meeting_service_signature()
test_database_schema()
test_api_endpoints()
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)