imetting_backend/test/test_kb_prompt_id_feature.py

167 lines
5.8 KiB
Python
Raw Normal View History

2025-12-11 08:48:12 +00:00
"""
测试知识库提示词模版选择功能
"""
import sys
sys.path.insert(0, 'app')
from app.services.llm_service import LLMService
from app.services.async_knowledge_base_service import AsyncKnowledgeBaseService
from app.core.database import get_db_connection
def test_get_active_knowledge_prompts():
"""测试获取启用的知识库提示词列表"""
print("\n=== 测试1: 获取启用的知识库提示词列表 ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 获取KNOWLEDGE_TASK类型的启用模版
query = """
SELECT id, name, is_default
FROM prompts
WHERE task_type = %s AND is_active = TRUE
ORDER BY is_default DESC, created_at DESC
"""
cursor.execute(query, ('KNOWLEDGE_TASK',))
prompts = cursor.fetchall()
print(f"✓ 找到 {len(prompts)} 个启用的知识库任务模版:")
for p in prompts:
default_flag = " [默认]" if p['is_default'] else ""
print(f" - ID: {p['id']}, 名称: {p['name']}{default_flag}")
return prompts
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def test_get_task_prompt_with_id(prompts):
"""测试通过prompt_id获取知识库提示词内容"""
print("\n=== 测试2: 通过prompt_id获取知识库提示词内容 ===")
if not prompts:
print("⚠ 没有可用的提示词模版,跳过测试")
return
llm_service = LLMService()
# 测试获取第一个提示词
test_prompt = prompts[0]
try:
content = llm_service.get_task_prompt('KNOWLEDGE_TASK', prompt_id=test_prompt['id'])
print(f"✓ 成功获取提示词 ID={test_prompt['id']}, 名称={test_prompt['name']}")
print(f" 内容长度: {len(content)} 字符")
print(f" 内容预览: {content[:100]}...")
except Exception as e:
print(f"✗ 测试失败: {e}")
import traceback
traceback.print_exc()
# 测试获取默认提示词不指定prompt_id
try:
default_content = llm_service.get_task_prompt('KNOWLEDGE_TASK')
print(f"✓ 成功获取默认提示词")
print(f" 内容长度: {len(default_content)} 字符")
except Exception as e:
print(f"✗ 获取默认提示词失败: {e}")
def test_async_kb_service_signature():
"""测试async_knowledge_base_service的方法签名"""
print("\n=== 测试3: 验证方法签名支持prompt_id参数 ===")
import inspect
async_service = AsyncKnowledgeBaseService()
# 检查start_generation方法签名
sig = inspect.signature(async_service.start_generation)
params = list(sig.parameters.keys())
if 'prompt_id' in params:
print(f"✓ start_generation 方法支持 prompt_id 参数")
print(f" 参数列表: {params}")
else:
print(f"✗ start_generation 方法缺少 prompt_id 参数")
print(f" 参数列表: {params}")
# 检查_build_prompt方法签名
sig2 = inspect.signature(async_service._build_prompt)
params2 = list(sig2.parameters.keys())
if 'prompt_id' in params2:
print(f"✓ _build_prompt 方法支持 prompt_id 参数")
print(f" 参数列表: {params2}")
else:
print(f"✗ _build_prompt 方法缺少 prompt_id 参数")
print(f" 参数列表: {params2}")
def test_database_schema():
"""测试数据库schema是否包含prompt_id列"""
print("\n=== 测试4: 验证数据库schema ===")
try:
with get_db_connection() as connection:
cursor = connection.cursor(dictionary=True)
# 检查knowledge_base_tasks表是否有prompt_id列
cursor.execute("""
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = DATABASE()
AND TABLE_NAME = 'knowledge_base_tasks'
AND COLUMN_NAME = 'prompt_id'
""")
result = cursor.fetchone()
if result:
print(f"✓ knowledge_base_tasks 表包含 prompt_id 列")
print(f" 类型: {result['DATA_TYPE']}")
print(f" 可空: {result['IS_NULLABLE']}")
print(f" 默认值: {result['COLUMN_DEFAULT']}")
else:
print(f"✗ knowledge_base_tasks 表缺少 prompt_id 列")
except Exception as e:
print(f"✗ 数据库检查失败: {e}")
import traceback
traceback.print_exc()
def test_api_model():
"""测试API模型定义"""
print("\n=== 测试5: 验证API模型定义 ===")
try:
from app.models.models import CreateKnowledgeBaseRequest
import inspect
# 检查CreateKnowledgeBaseRequest模型
fields = CreateKnowledgeBaseRequest.model_fields
if 'prompt_id' in fields:
print(f"✓ CreateKnowledgeBaseRequest 包含 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
else:
print(f"✗ CreateKnowledgeBaseRequest 缺少 prompt_id 字段")
print(f" 字段列表: {list(fields.keys())}")
except Exception as e:
print(f"✗ API模型检查失败: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
print("=" * 60)
print("开始测试知识库提示词模版选择功能")
print("=" * 60)
# 运行所有测试
prompts = test_get_active_knowledge_prompts()
test_get_task_prompt_with_id(prompts)
test_async_kb_service_signature()
test_database_schema()
test_api_model()
print("\n" + "=" * 60)
print("测试完成")
print("=" * 60)