imetting_backend/app/api/endpoints/prompts.py

241 lines
9.6 KiB
Python
Raw Normal View History

2025-09-30 04:14:19 +00:00
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
# Router holding the /prompts endpoints defined below; presumably included
# into the main FastAPI app elsewhere (confirm the mount prefix there).
router = APIRouter()
# Pydantic Models
class PromptIn(BaseModel):
    """Request body for creating or updating a prompt template.

    Also serves as the base class of PromptOut.
    """

    name: str
    # Expected values: 'MEETING_TASK' or 'KNOWLEDGE_TASK'. Note this is a plain
    # str, so other values are not rejected by the model itself.
    task_type: str
    content: str
    # When True, the create/update endpoints clear is_default on the other
    # prompts of the same task_type first.
    is_default: bool = False
    is_active: bool = True
class PromptOut(PromptIn):
    """A stored prompt: the PromptIn fields plus server-assigned columns."""

    id: int
    # Set from current_user["user_id"] when the prompt is created.
    creator_id: int
    # NOTE(review): the endpoints select created_at straight from the DB; if the
    # driver returns a datetime, validating into this str field may need an
    # explicit conversion — confirm before using PromptOut as a response_model.
    created_at: str
class PromptListResponse(BaseModel):
    """Shape of a paginated prompt listing: one page of prompts plus the total count."""

    prompts: List[PromptOut]
    # Total number of matching prompts across all pages.
    total: int
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Create a new prompt owned by the current user.

    If ``prompt.is_default`` is set, the default flag is first cleared on every
    other prompt with the same ``task_type``, so the new prompt becomes the
    only default for that task type.

    Returns:
        Code "200" with the new row id plus the submitted fields; code "400"
        when the database reports a duplicate entry (surfaced as a name
        collision); code "500" for any other error.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Only one default per task type: clear the flag on the others first.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s",
                    (prompt.task_type,)
                )
            cursor.execute(
                """INSERT INTO prompts (name, task_type, content, is_default, is_active, creator_id)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, current_user["user_id"])
            )
            connection.commit()
            new_id = cursor.lastrowid
            return create_api_response(
                code="200",
                message="提示词创建成功",
                data={"id": new_id, **prompt.dict()}
            )
        except Exception as e:
            # Undo the uncommitted default-flag reset (and any partial insert)
            # so a failed create neither leaves that work pending on the
            # connection nor lets a later commit on it apply it.
            try:
                connection.rollback()
            except Exception:
                pass  # best effort: the original error is what gets reported
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List the enabled prompts of one task type, for any creator.

    Each entry carries only ``id``, ``name`` and ``is_default``. Default
    prompts come first, then newer prompts before older ones. The caller must
    be authenticated, but results are not filtered by creator.
    """
    active_sql = (
        "SELECT id, name, is_default "
        "FROM prompts "
        "WHERE task_type = %s AND is_active = TRUE "
        "ORDER BY is_default DESC, created_at DESC"
    )
    with get_db_connection() as connection:
        db_cursor = connection.cursor(dictionary=True)
        db_cursor.execute(active_sql, (task_type,))
        rows = db_cursor.fetchall()
        payload = {"prompts": rows}
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data=payload
        )
@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    current_user: dict = Depends(get_current_user)
):
    """Get a paginated list of the current user's prompts, optionally filtered by task_type.

    Args:
        task_type: If non-empty, only prompts with this task type are listed.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0 (an empty page,
            with ``total`` still reported).
        current_user: Authenticated user; only prompts whose ``creator_id``
            equals ``current_user["user_id"]`` are listed.

    Returns:
        Code "200" with ``data = {"prompts": [...], "total": <count>}``,
        newest prompts first.
    """
    # Clamp pagination input. page < 1 or size < 0 used to produce a negative
    # LIMIT/OFFSET, which the database (MySQL requires non-negative values)
    # rejects, turning a bad query string into an unhandled server error.
    page = max(page, 1)
    size = max(size, 0)
    offset = (page - 1) * size

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Build the WHERE clause from fixed fragments only; every value is
        # passed as a bound parameter, so the f-strings below are not injectable.
        where_conditions = ["creator_id = %s"]
        params = [current_user["user_id"]]
        if task_type:
            where_conditions.append("task_type = %s")
            params.append(task_type)
        where_clause = " AND ".join(where_conditions)

        # Total across all pages.
        cursor.execute(
            f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
            tuple(params)
        )
        total = cursor.fetchone()['total']

        # The requested page.
        cursor.execute(
            f"""SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at
                FROM prompts
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s""",
            tuple(params + [size, offset])
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": prompts, "total": total}
        )
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Fetch one prompt by primary key.

    Responds with code "200" and the row's columns as ``data`` when the id
    exists, otherwise code "404". The caller must be authenticated; the
    lookup is not restricted to prompts the caller created.
    """
    with get_db_connection() as connection:
        db_cursor = connection.cursor(dictionary=True)
        db_cursor.execute(
            "SELECT id, name, task_type, content, is_default, is_active, creator_id, created_at "
            "FROM prompts WHERE id = %s",
            (prompt_id,)
        )
        row = db_cursor.fetchone()
        if row:
            return create_api_response(code="200", message="获取提示词成功", data=row)
        return create_api_response(code="404", message="提示词不存在")
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Update an existing prompt. Only the creator can update their own prompts.

    Every field of ``prompt`` overwrites the stored row. If
    ``prompt.is_default`` is set, the default flag is first cleared on all
    other prompts with the same ``task_type``.

    Returns:
        Code "404" if no prompt has this id; "403" if it belongs to another
        user; "400" when the database reports a duplicate entry (surfaced as
        a name collision); "500" on any other error; "200" on success.
    """
    logger = logging.getLogger(__name__)
    user_id = current_user["user_id"]
    logger.debug(
        "[UPDATE PROMPT] prompt_id=%s user_id=%s name=%s task_type=%s "
        "content_len=%d is_default=%s is_active=%s",
        prompt_id, user_id, prompt.name, prompt.task_type,
        len(prompt.content), prompt.is_default, prompt.is_active,
    )

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Check existence first.
            cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                logger.debug("[UPDATE PROMPT] prompt %s not found", prompt_id)
                return create_api_response(code="404", message="提示词不存在")

            # creator_id was selected but never checked, so any authenticated
            # user could overwrite anyone's prompt. Enforce the same ownership
            # rule as delete_prompt.
            if existing["creator_id"] != user_id:
                logger.debug(
                    "[UPDATE PROMPT] user %s denied: prompt %s owned by %s",
                    user_id, prompt_id, existing["creator_id"],
                )
                return create_api_response(code="403", message="无权修改其他用户的提示词")

            # Only one default per task type: clear the flag on the others first.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, prompt_id)
                )
                logger.debug("[UPDATE PROMPT] cleared %s other default prompts", cursor.rowcount)

            cursor.execute(
                """UPDATE prompts
                   SET name = %s, task_type = %s, content = %s, is_default = %s, is_active = %s
                   WHERE id = %s""",
                (prompt.name, prompt.task_type, prompt.content, prompt.is_default,
                 prompt.is_active, prompt_id)
            )
            # rowcount == 0 does not mean the row is missing (existence was
            # checked above); it can be 0 when every column already held the
            # new value, so commit unconditionally.
            logger.debug("[UPDATE PROMPT] UPDATE affected %s rows", cursor.rowcount)
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            logger.exception("[UPDATE PROMPT] failed for prompt_id=%s", prompt_id)
            # Undo the uncommitted default-flag reset so a failed update does
            # not leave it pending on the connection.
            try:
                connection.rollback()
            except Exception:
                pass  # best effort: the original error is what gets reported
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt. Only the creator can delete their own prompts.

    Deletion is refused while any meeting or knowledge base still references
    the prompt through its ``prompt_id`` column.

    Returns:
        Code "404" if no prompt has this id; "403" if it belongs to another
        user; "400" (with ``meeting_count`` and ``kb_count`` in ``data``)
        if it is still referenced; "200" after a successful delete.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Check that the prompt exists and belongs to the current user.
        cursor.execute(
            "SELECT creator_id FROM prompts WHERE id = %s",
            (prompt_id,)
        )
        prompt = cursor.fetchone()
        if not prompt:
            return create_api_response(code="404", message="提示词不存在")
        if prompt['creator_id'] != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")

        # Count meetings that reference this prompt.
        cursor.execute(
            "SELECT COUNT(*) as count FROM meetings WHERE prompt_id = %s",
            (prompt_id,)
        )
        meeting_count = cursor.fetchone()['count']

        # Count knowledge bases that reference this prompt.
        cursor.execute(
            "SELECT COUNT(*) as count FROM knowledge_bases WHERE prompt_id = %s",
            (prompt_id,)
        )
        kb_count = cursor.fetchone()['count']

        # Refuse the delete while anything still references the prompt.
        if meeting_count > 0 or kb_count > 0:
            references = []
            if meeting_count > 0:
                references.append(f"{meeting_count}个会议")
            if kb_count > 0:
                references.append(f"{kb_count}个知识库")
            # Join with "和" so both counts don't run together
            # (previously rendered as "被3个会议2个知识库引用").
            return create_api_response(
                code="400",
                message=f"无法删除:该提示词被{'和'.join(references)}引用",
                data={
                    "meeting_count": meeting_count,
                    "kb_count": kb_count
                }
            )

        cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        connection.commit()
        return create_api_response(code="200", message="提示词删除成功")