2026-01-19 11:03:08 +00:00
|
|
|
|
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
|
|
|
|
|
|
|
|
|
|
|
|
# Router that the prompt CRUD endpoints below are registered on; presumably
# included into the application elsewhere (prefix/tags not visible here).
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
# Pydantic Models
|
|
|
|
|
|
class PromptIn(BaseModel):
    """Request body accepted by the create and update prompt endpoints."""

    name: str
    # Expected to be 'MEETING_TASK' or 'KNOWLEDGE_TASK'; not validated by this model.
    task_type: str
    content: str
    # Template description (optional).
    desc: Optional[str] = None
    # When True, create/update first clear is_default on other prompts of the same task_type.
    is_default: bool = False
    is_active: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
class PromptOut(PromptIn):
    """A stored prompt: the input fields plus database-assigned columns.

    NOTE(review): none of the endpoints visible in this file declare this as a
    response_model; they return plain dicts via create_api_response.
    """

    id: int
    creator_id: int
    # Typed as str; confirm the DB value is serialized before validation (it may be a datetime).
    created_at: str
|
|
|
|
|
|
|
|
|
|
|
|
class PromptListResponse(BaseModel):
    """Paginated list of prompts together with the total number of matches.

    NOTE(review): not referenced by the endpoints visible in this file.
    """

    prompts: List[PromptOut]
    # Count of all matching prompts, not just those in `prompts`.
    total: int
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/prompts")
def create_prompt(prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Create a new prompt owned by the current user.

    If ``prompt.is_default`` is true, ``is_default`` is first reset to FALSE on
    every existing prompt with the same ``task_type``, so the new prompt is the
    only default for that task type.

    Args:
        prompt: Validated request body.
        current_user: Authenticated user from ``get_current_user``; its
            ``"user_id"`` is stored as the new prompt's ``creator_id``.

    Returns:
        ``create_api_response`` with code "200" and the new id merged with the
        submitted fields on success; code "400" when the database reports a
        duplicate entry (reported as a duplicate name); code "500" with the
        error text for any other failure.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Only one default per task type: clear the flag on existing rows first.
            # NOTE(review): this ignores creator_id, so any user can unset another
            # user's (or an admin's) default prompt — confirm this is intended.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s",
                    (prompt.task_type,)
                )

            # `desc` is a reserved word in MySQL, so it must be backtick-quoted.
            cursor.execute(
                """INSERT INTO prompts (name, task_type, content, `desc`, is_default, is_active, creator_id)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)""",
                (prompt.name, prompt.task_type, prompt.content, prompt.desc,
                 prompt.is_default, prompt.is_active, current_user["user_id"])
            )
            new_id = cursor.lastrowid
            connection.commit()
            return create_api_response(
                code="200",
                message="提示词创建成功",
                data={"id": new_id, **prompt.dict()}
            )
        except Exception as e:
            # Discard the partially applied transaction (e.g. the default-flag
            # reset above) so it is not committed later if the connection is reused.
            connection.rollback()
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"创建提示词失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """Get all active prompts for a specific task type.

    Returns active prompts of ``task_type`` whose ``creator_id`` is either 1 or
    the current user's id. Default prompts are listed first, then the rest from
    newest to oldest.

    NOTE(review): the shared set is selected by ``creator_id = 1`` (a single
    user id, presumably the built-in administrator account), not by the
    creator's role. Prompts created by other administrator accounts are NOT
    included. Confirm this matches the intended "administrator prompts" rule.

    Args:
        task_type: Task type to filter on.
        current_user: Authenticated user; its ``"user_id"`` selects the
            caller's own prompts.

    Returns:
        ``create_api_response`` with ``data={"prompts": [...]}``. Each row has
        ``id``, ``name``, ``desc``, ``content`` and ``is_default``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # `desc` is a reserved word in MySQL, so it must be backtick-quoted.
        cursor.execute(
            """SELECT id, name, `desc`, content, is_default
               FROM prompts
               WHERE task_type = %s AND is_active = TRUE
                 AND (creator_id = 1 OR creator_id = %s)
               ORDER BY is_default DESC, created_at DESC""",
            (task_type, current_user["user_id"])
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data={"prompts": prompts}
        )
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    current_user: dict = Depends(get_current_user)
):
    """Get a paginated list of the current user's prompts.

    Args:
        task_type: If given, only prompts with this task type are returned.
        page: 1-based page number; values below 1 are treated as 1.
        size: Page size; negative values are treated as 0.
        current_user: Authenticated user; only prompts whose ``creator_id``
            equals its ``"user_id"`` are listed.

    Returns:
        ``create_api_response`` with ``data={"prompts": [...], "total": n}``.
        ``total`` counts all matching prompts across pages, and the page is
        ordered newest first.
    """
    # A page below 1 or a negative size would yield a negative OFFSET/LIMIT,
    # which MySQL rejects.
    page = max(page, 1)
    size = max(size, 0)

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        # Build the WHERE clause from fixed fragments only; all values are bound
        # as query parameters, so the f-strings below are injection-safe.
        where_conditions = ["creator_id = %s"]
        params = [current_user["user_id"]]
        if task_type:
            where_conditions.append("task_type = %s")
            params.append(task_type)
        where_clause = " AND ".join(where_conditions)

        # Total number of matches (ignoring pagination).
        cursor.execute(
            f"SELECT COUNT(*) as total FROM prompts WHERE {where_clause}",
            tuple(params)
        )
        total = cursor.fetchone()['total']

        # Requested page. `desc` is a MySQL reserved word and must be backtick-quoted.
        offset = (page - 1) * size
        cursor.execute(
            f"""SELECT id, name, task_type, content, `desc`, is_default, is_active, creator_id, created_at
                FROM prompts
                WHERE {where_clause}
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s""",
            tuple(params + [size, offset])
        )
        prompts = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": prompts, "total": total}
        )
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/prompts/{prompt_id}")
def get_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Get a single prompt by its ID.

    NOTE(review): ``current_user`` only enforces authentication here. Any
    logged-in user can read any prompt by id, unlike ``get_prompts`` (own
    prompts only) and ``get_active_prompts`` (own + creator_id 1). Confirm
    whether a visibility check is needed.

    Args:
        prompt_id: Primary key of the prompt.
        current_user: Authenticated user (not otherwise used).

    Returns:
        ``create_api_response`` with code "200" and the prompt row, or code
        "404" if no prompt has that id.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        # `desc` is a reserved word in MySQL, so it must be backtick-quoted.
        cursor.execute(
            """SELECT id, name, task_type, content, `desc`, is_default, is_active, creator_id, created_at
               FROM prompts WHERE id = %s""",
            (prompt_id,)
        )
        prompt = cursor.fetchone()
        if not prompt:
            return create_api_response(code="404", message="提示词不存在")
        return create_api_response(code="200", message="获取提示词成功", data=prompt)
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptIn, current_user: dict = Depends(get_current_user)):
    """Update an existing prompt owned by the current user.

    All fields of ``prompt`` overwrite the stored row. If ``prompt.is_default``
    is true, ``is_default`` is first reset on every other prompt with the same
    ``task_type``.

    Args:
        prompt_id: Primary key of the prompt to update.
        prompt: New field values.
        current_user: Authenticated user; only the prompt's creator (matching
            ``"user_id"``) may update it, mirroring ``delete_prompt``.

    Returns:
        ``create_api_response`` with code "200" on success, "404" if the prompt
        does not exist, "403" if it belongs to another user, "400" on a
        duplicate entry (reported as a duplicate name), or "500" with the error
        text for any other failure.
    """
    logger = logging.getLogger(__name__)
    logger.debug(
        "[UPDATE PROMPT] prompt_id=%s user_id=%s name=%s task_type=%s "
        "content_len=%d is_default=%s is_active=%s",
        prompt_id, current_user["user_id"], prompt.name, prompt.task_type,
        len(prompt.content), prompt.is_default, prompt.is_active,
    )

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            # Check that the record exists and who owns it before writing anything.
            cursor.execute("SELECT id, creator_id FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            logger.debug("[UPDATE PROMPT] existing record: %s", existing)
            if not existing:
                return create_api_response(code="404", message="提示词不存在")

            # Only the creator may modify a prompt (same rule as delete_prompt).
            if existing['creator_id'] != current_user["user_id"]:
                return create_api_response(code="403", message="无权修改其他用户的提示词")

            # Only one default per task type: clear the flag on the other rows first.
            # NOTE(review): this ignores creator_id, so it also clears other
            # users' defaults — confirm this is intended.
            if prompt.is_default:
                cursor.execute(
                    "UPDATE prompts SET is_default = FALSE WHERE task_type = %s AND id != %s",
                    (prompt.task_type, prompt_id)
                )
                logger.debug("[UPDATE PROMPT] Cleared %s other default prompts", cursor.rowcount)

            # `desc` is a reserved word in MySQL, so it must be backtick-quoted.
            cursor.execute(
                """UPDATE prompts
                   SET name = %s, task_type = %s, content = %s, `desc` = %s, is_default = %s, is_active = %s
                   WHERE id = %s""",
                (prompt.name, prompt.task_type, prompt.content, prompt.desc, prompt.is_default,
                 prompt.is_active, prompt_id)
            )
            # rowcount == 0 does not mean the row is missing (existence was checked
            # above); it can simply mean every column already had these values.
            logger.debug("[UPDATE PROMPT] UPDATE affected %s rows", cursor.rowcount)
            connection.commit()
            return create_api_response(code="200", message="提示词更新成功")
        except Exception as e:
            # Discard the partially applied transaction (e.g. the default-flag reset).
            connection.rollback()
            logger.exception("[UPDATE PROMPT] Failed to update prompt %s", prompt_id)
            if "Duplicate entry" in str(e):
                return create_api_response(code="400", message="提示词名称已存在")
            return create_api_response(code="500", message=f"更新提示词失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt owned by the caller.

    The row is removed only if it exists, belongs to the current user, and is
    not referenced by any meeting or knowledge base.

    Args:
        prompt_id: Primary key of the prompt to delete.
        current_user: Authenticated user; its ``"user_id"`` must equal the
            prompt's ``creator_id``.

    Returns:
        ``create_api_response`` with one of these codes:

        - "404" if the prompt does not exist.
        - "403" if it belongs to another user.
        - "400" if it is still referenced. The response data then carries
          ``meeting_count`` and ``kb_count``.
        - "200" after the row has been deleted and committed.
    """
    with get_db_connection() as connection:
        cur = connection.cursor(dictionary=True)

        # Existence and ownership guard.
        cur.execute(
            "SELECT creator_id FROM prompts WHERE id = %s",
            (prompt_id,)
        )
        owner_row = cur.fetchone()
        if not owner_row:
            return create_api_response(code="404", message="提示词不存在")
        if owner_row['creator_id'] != current_user["user_id"]:
            return create_api_response(code="403", message="无权删除其他用户的提示词")

        # Count references: meetings first, then knowledge bases.
        cur.execute(
            "SELECT COUNT(*) as count FROM meetings WHERE prompt_id = %s",
            (prompt_id,)
        )
        meeting_count = cur.fetchone()['count']
        cur.execute(
            "SELECT COUNT(*) as count FROM knowledge_bases WHERE prompt_id = %s",
            (prompt_id,)
        )
        kb_count = cur.fetchone()['count']

        # A prompt that is still in use must not be deleted.
        usages = [(meeting_count, "个会议"), (kb_count, "个知识库")]
        in_use = [f"{amount}{label}" for amount, label in usages if amount > 0]
        if in_use:
            joined = ' 和 '.join(in_use)
            return create_api_response(
                code="400",
                message=f"无法删除:该提示词被{joined}引用",
                data={"meeting_count": meeting_count, "kb_count": kb_count}
            )

        cur.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
        connection.commit()
        return create_api_response(code="200", message="提示词删除成功")
|