2026-03-26 06:55:12 +00:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
2026-01-19 11:03:08 +00:00
|
|
|
from pydantic import BaseModel
|
2026-03-26 06:55:12 +00:00
|
|
|
from typing import Optional, List
|
2026-01-19 11:03:08 +00:00
|
|
|
|
|
|
|
|
from app.core.auth import get_current_user
|
|
|
|
|
from app.core.database import get_db_connection
|
|
|
|
|
from app.core.response import create_api_response
|
2026-03-26 06:55:12 +00:00
|
|
|
from app.models.models import PromptCreate, PromptUpdate
|
2026-01-19 11:03:08 +00:00
|
|
|
|
|
|
|
|
# Router on which every prompt-template endpoint in this module is registered.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
2026-03-26 06:55:12 +00:00
|
|
|
class PromptConfigItem(BaseModel):
    """One entry of a user's per-task-type prompt selection."""

    # ID of the referenced row in the `prompts` table.
    prompt_id: int
    # update_prompt_config only stores entries whose is_enabled is True.
    is_enabled: bool = True
    # Ordering hint; update_prompt_config sorts entries ascending by
    # (sort_order, prompt_id) and then renumbers them.
    sort_order: int = 0


class PromptConfigUpdateRequest(BaseModel):
    """Request body for saving a user's prompt selection for one task type."""

    # Complete new selection: update_prompt_config deletes the user's existing
    # config rows for the task type and rebuilds them from these items.
    items: List[PromptConfigItem]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_admin(user: dict) -> bool:
|
|
|
|
|
return int(user.get("role_id") or 0) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _can_manage_prompt(current_user: dict, row: dict) -> bool:
    """Decide whether *current_user* may modify or delete the prompt *row*.

    Admins may manage any prompt. Other users may manage only prompts they
    created (``creator_id`` equals their ``user_id``) that are not system
    prompts. Missing or falsy ``creator_id`` / ``is_system`` values count as 0.
    """
    if _is_admin(current_user):
        return True

    # Ownership is evaluated first; is_system is only inspected for the owner.
    owner_id = int(row.get("creator_id") or 0)
    if owner_id != int(current_user["user_id"]):
        return False
    return int(row.get("is_system") or 0) == 0
|
2026-01-19 11:03:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/prompts")
def create_prompt(
    prompt: PromptCreate,
    current_user: dict = Depends(get_current_user),
):
    """Create a prompt template. Admin can create system prompts, others can only create personal prompts.

    The requested ``is_system`` flag is honoured only for admins. For everyone
    else it is ignored and a personal prompt is created.

    The new prompt becomes the default of its group (same task_type, is_system
    and creator) in two cases: it is the first prompt in that group, or
    ``is_default`` was requested. The group's previous default is cleared in
    the same transaction.

    Returns:
        API response whose data is ``{"id": <new prompt id>}``.

    Raises:
        HTTPException(400): ``is_default`` was requested for an inactive prompt.
        HTTPException(500): unexpected database error (transaction rolled back).
    """
    # update_prompt enforces that a default template must stay enabled; apply
    # the same rule here. This runs before any database work, so nothing needs
    # rolling back and the generic handler below cannot turn it into a 500.
    if prompt.is_default and not prompt.is_active:
        raise HTTPException(status_code=400, detail="默认模版必须处于启用状态")

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            is_admin = _is_admin(current_user)
            requested_is_system = bool(getattr(prompt, "is_system", False))
            # Non-admin requests for a system prompt are silently downgraded.
            is_system = 1 if (is_admin and requested_is_system) else 0
            owner_user_id = current_user["user_id"]

            cursor.execute(
                """
                SELECT COUNT(*) as cnt
                FROM prompts
                WHERE task_type = %s
                  AND is_system = %s
                  AND creator_id = %s
                """,
                (prompt.task_type, is_system, owner_user_id),
            )
            count = (cursor.fetchone() or {}).get("cnt", 0)
            # The first prompt of a group always becomes its default.
            # NOTE(review): this implicit default is not checked against
            # is_active; confirm whether an inactive first prompt is acceptable.
            is_default = 1 if count == 0 else (1 if prompt.is_default else 0)

            if is_default:
                # Keep at most one default per (task_type, is_system, creator) group.
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s
                      AND is_system = %s
                      AND creator_id = %s
                    """,
                    (prompt.task_type, is_system, owner_user_id),
                )

            cursor.execute(
                """
                INSERT INTO prompts (name, task_type, content, `desc`, is_default, is_active, creator_id, is_system)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    prompt.name,
                    prompt.task_type,
                    prompt.content,
                    prompt.desc,
                    is_default,
                    1 if prompt.is_active else 0,
                    owner_user_id,
                    is_system,
                ),
            )
            prompt_id = cursor.lastrowid
            connection.commit()
            return create_api_response(code="200", message="提示词模版创建成功", data={"id": prompt_id})
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
|
2026-01-19 11:03:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 12,
    keyword: Optional[str] = Query(None),
    is_active: Optional[int] = Query(None),
    scope: str = Query("mine"),  # mine / system / all / accessible
    current_user: dict = Depends(get_current_user),
):
    """Get paginated prompt cards filtered by scope and optional criteria.

    Scopes:
        mine (default, also used for unrecognised values): the caller's
            personal (non-system) prompts.
        system: all system prompts, active or not.
        all: system prompts plus prompts created by the caller. Admin only;
            non-admins are downgraded to ``accessible``.
        accessible: active system prompts plus the caller's personal prompts.

    Optional filters:
        task_type: exact match.
        keyword: literal substring match on name or desc.
        is_active: 0 or 1; any other value is ignored.

    Paging: ``page`` values below 1 are treated as 1. Negative ``size`` values
    are treated as 0, which returns only the total. The normalised values are
    echoed in the response.
    """
    # page < 1 already mapped to offset 0. A negative size would yield an
    # invalid LIMIT and an unhandled database error.
    page = max(page, 1)
    size = max(size, 0)

    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        is_admin = _is_admin(current_user)
        where_conditions = []
        params = []

        if scope == "all" and not is_admin:
            scope = "accessible"

        if scope == "system":
            where_conditions.append("p.is_system = 1")
        elif scope == "all":
            where_conditions.append("(p.is_system = 1 OR p.creator_id = %s)")
            params.append(current_user["user_id"])
        elif scope == "accessible":
            where_conditions.append("((p.is_system = 1 AND p.is_active = 1) OR (p.is_system = 0 AND p.creator_id = %s))")
            params.append(current_user["user_id"])
        else:
            where_conditions.append("p.is_system = 0 AND p.creator_id = %s")
            params.append(current_user["user_id"])

        if task_type:
            where_conditions.append("p.task_type = %s")
            params.append(task_type)

        if keyword:
            # Escape LIKE metacharacters so "%", "_" and "\" in the keyword
            # match literally. Backslash is MySQL's default LIKE escape character.
            escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            like = f"%{escaped}%"
            where_conditions.append("(p.name LIKE %s OR p.`desc` LIKE %s)")
            params.extend([like, like])

        if is_active in (0, 1):
            where_conditions.append("p.is_active = %s")
            params.append(is_active)

        # where_clause is assembled only from the fixed fragments above; every
        # client-supplied value travels as a bound parameter.
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) as total FROM prompts p WHERE {where_clause}", tuple(params))
        total = (cursor.fetchone() or {}).get("total", 0)

        offset = (page - 1) * size
        cursor.execute(
            f"""
            SELECT p.id, p.name, p.task_type, p.content, p.`desc`, p.is_default, p.is_active,
                   p.creator_id, p.is_system, p.created_at,
                   u.caption AS creator_name
            FROM prompts p
            LEFT JOIN sys_users u ON u.user_id = p.creator_id
            WHERE {where_clause}
            ORDER BY p.is_system DESC, p.task_type ASC, p.is_default DESC, p.created_at DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params + [size, offset]),
        )
        rows = cursor.fetchall()

        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": rows, "total": total, "page": page, "size": size},
        )
|
|
|
|
|
|
2026-03-26 06:55:12 +00:00
|
|
|
|
|
|
|
|
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List the active prompts of *task_type* offered for task selection.

    Candidates are active prompts of this task type that are system prompts
    or were created by the caller. Each is joined with the caller's
    prompt_config row for the same task type, if one exists.

    Ordering: enabled entries first, then configured sort_order, then default
    prompts, then newest first.

    If the caller has enabled at least one candidate, only the enabled ones
    are returned; otherwise every candidate is returned.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        user_id = current_user["user_id"]
        cursor.execute(
            """
            SELECT p.id, p.name, p.`desc`, p.content, p.is_default, p.is_system, p.creator_id,
                   cfg.is_enabled, cfg.sort_order
            FROM prompts p
            LEFT JOIN prompt_config cfg
                ON cfg.prompt_id = p.id
               AND cfg.user_id = %s
               AND cfg.task_type = %s
            WHERE p.task_type = %s
              AND p.is_active = 1
              AND (p.is_system = 1 OR p.creator_id = %s)
            ORDER BY
                CASE WHEN cfg.is_enabled = 1 THEN 0 ELSE 1 END,
                cfg.sort_order ASC,
                p.is_default DESC,
                p.created_at DESC
            """,
            (user_id, task_type, task_type, user_id),
        )
        candidates = cursor.fetchall()

        # Fall back to all candidates when the user has enabled none of them.
        chosen = [row for row in candidates if row.get("is_enabled") == 1]
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data={"prompts": chosen or candidates},
        )
|
2026-01-19 11:03:08 +00:00
|
|
|
|
2026-03-26 06:55:12 +00:00
|
|
|
|
|
|
|
|
@router.get("/prompts/config/{task_type}")
def get_prompt_config(task_type: str, current_user: dict = Depends(get_current_user)):
    """Return the caller's prompt configuration for *task_type*.

    Response data:
        task_type: echoed back.
        available_prompts: active prompts of this task type that are either
            system prompts or created by the caller (system first, then
            defaults, then newest).
        configs: the caller's raw prompt_config rows for this task type,
            ordered by sort_order then config_id.
        selected_prompt_ids: prompt ids of enabled config rows, in config
            order, restricted to prompts present in ``available_prompts``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)

        cursor.execute(
            """
            SELECT id, name, task_type, content, `desc`, is_default, is_active, is_system, creator_id, created_at
            FROM prompts
            WHERE task_type = %s
              AND is_active = 1
              AND (is_system = 1 OR creator_id = %s)
            ORDER BY is_system DESC, is_default DESC, created_at DESC
            """,
            (task_type, current_user["user_id"]),
        )
        available = cursor.fetchall()

        cursor.execute(
            """
            SELECT prompt_id, is_enabled, sort_order
            FROM prompt_config
            WHERE user_id = %s AND task_type = %s
            ORDER BY sort_order ASC, config_id ASC
            """,
            (current_user["user_id"], task_type),
        )
        configs = cursor.fetchall()

        # A config row can outlive its prompt's eligibility: the prompt may be
        # deactivated, deleted, or moved to another task type. update_prompt_config
        # rejects such ids, so report as selected only prompts the caller can
        # save back.
        available_ids = {row["id"] for row in available}
        selected_prompt_ids = [
            item["prompt_id"]
            for item in configs
            if item.get("is_enabled") == 1 and item["prompt_id"] in available_ids
        ]

        return create_api_response(
            code="200",
            message="获取提示词配置成功",
            data={
                "task_type": task_type,
                "available_prompts": available,
                "configs": configs,
                "selected_prompt_ids": selected_prompt_ids,
            },
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/prompts/config/{task_type}")
def update_prompt_config(
    task_type: str,
    request: PromptConfigUpdateRequest,
    current_user: dict = Depends(get_current_user),
):
    """Replace the caller's prompt selection for *task_type*.

    Only items with ``is_enabled`` set are stored. Processing steps:
      * sort the items by (sort_order, prompt_id);
      * de-duplicate by prompt_id, keeping the first occurrence;
      * store them with sort_order renumbered 1..n.

    Every selected prompt must be active, belong to this task type, and be
    either a system prompt or created by the caller.

    Raises:
        HTTPException(400): some selected prompt ids are not valid for the caller.
        HTTPException(500): unexpected database error.
    The transaction is rolled back in both cases.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            user_id = current_user["user_id"]
            enabled_items = sorted(
                (item for item in request.items if item.is_enabled),
                key=lambda x: (x.sort_order, x.prompt_id),
            )
            # Without de-duplication, a prompt listed twice would be inserted
            # twice. That either hits a uniqueness constraint (if the table has
            # one) or duplicates the prompt in get_active_prompts' LEFT JOIN.
            # dict.fromkeys keeps first-occurrence order.
            selected_ids = list(dict.fromkeys(int(item.prompt_id) for item in enabled_items))

            if selected_ids:
                placeholders = ",".join(["%s"] * len(selected_ids))
                cursor.execute(
                    f"""
                    SELECT id
                    FROM prompts
                    WHERE id IN ({placeholders})
                      AND task_type = %s
                      AND is_active = 1
                      AND (is_system = 1 OR creator_id = %s)
                    """,
                    tuple(selected_ids + [task_type, user_id]),
                )
                valid_ids = {row["id"] for row in cursor.fetchall()}
                invalid_ids = [pid for pid in selected_ids if pid not in valid_ids]
                if invalid_ids:
                    raise HTTPException(status_code=400, detail=f"存在无效提示词ID: {invalid_ids}")

            cursor.execute(
                "DELETE FROM prompt_config WHERE user_id = %s AND task_type = %s",
                (user_id, task_type),
            )

            for position, selected_id in enumerate(selected_ids, start=1):
                cursor.execute(
                    """
                    INSERT INTO prompt_config (user_id, task_type, prompt_id, is_enabled, sort_order)
                    VALUES (%s, %s, %s, 1, %s)
                    """,
                    (user_id, task_type, selected_id, position),
                )

            connection.commit()
            return create_api_response(code="200", message="提示词配置保存成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
|
2026-01-19 11:03:08 +00:00
|
|
|
|
2026-03-26 06:55:12 +00:00
|
|
|
|
|
|
|
|
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update a prompt template; only fields sent in the request are written.

    Access: admins may edit any prompt. Other users may edit only their own
    non-system prompts, and only admins may send ``is_system``.

    Default-template rules:
      * a current default cannot be un-defaulted directly; another template
        must be made default instead;
      * a template that is, or is becoming, the default cannot be deactivated;
      * making a template default clears the default flag of the other
        templates in the group it will belong to after the update (same
        task_type, is_system and creator).

    All checks run before any write.

    Raises:
        HTTPException 404 / 403 / 400: for the checks above.
        HTTPException 500: for unexpected database errors.
    The transaction is rolled back on every error.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute(
                "SELECT id, creator_id, task_type, is_default, is_system FROM prompts WHERE id = %s",
                (prompt_id,),
            )
            existing = cursor.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail="模版不存在")
            if not _can_manage_prompt(current_user, existing):
                raise HTTPException(status_code=403, detail="无权修改此模版")

            if prompt.is_default is False and existing["is_default"]:
                raise HTTPException(status_code=400, detail="必须保留一个默认模版,请先设置其他模版为默认")
            if prompt.is_system is not None and not _is_admin(current_user):
                raise HTTPException(status_code=403, detail="普通用户不能修改系统提示词属性")

            # Default status after this update: the requested value if one was
            # sent, otherwise the current value.
            if prompt.is_default is not None:
                will_be_default = bool(prompt.is_default)
            else:
                will_be_default = bool(existing["is_default"])
            # Validated before any write. Previously this check ran after the
            # default-reset UPDATE, leaving that write pending on rejection.
            if will_be_default and prompt.is_active is False:
                raise HTTPException(status_code=400, detail="默认模版必须处于启用状态")

            if prompt.is_default:
                # Reset the default in the group the prompt belongs to *after*
                # this update; task_type and (for admins) is_system may change.
                task_type = prompt.task_type or existing["task_type"]
                if prompt.is_system is not None:
                    group_is_system = 1 if prompt.is_system else 0
                else:
                    group_is_system = existing.get("is_system", 0)
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s
                      AND is_system = %s
                      AND creator_id = %s
                    """,
                    (task_type, group_is_system, existing["creator_id"]),
                )
            # NOTE(review): moving an existing default to another task_type
            # leaves its old group without a default; confirm whether that
            # should be rejected.

            update_fields = []
            params = []
            # Column names come from the PromptUpdate model's fields (assumes
            # the model does not accept extra fields; confirm). `desc` is
            # backtick-quoted because DESC is a reserved word in MySQL.
            prompt_data = prompt.dict(exclude_unset=True)
            for field, value in prompt_data.items():
                column = "`desc`" if field == "desc" else field
                update_fields.append(f"{column} = %s")
                params.append(value)

            if update_fields:
                params.append(prompt_id)
                cursor.execute(f"UPDATE prompts SET {', '.join(update_fields)} WHERE id = %s", tuple(params))

            connection.commit()
            return create_api_response(code="200", message="更新成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt template together with any prompt_config rows referencing it.

    Raises:
        HTTPException(404): the prompt does not exist.
        HTTPException(403): the caller may not manage this prompt.
        HTTPException(400): the prompt is a default template.
        HTTPException(500): unexpected database error (transaction rolled back).
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute("SELECT id, creator_id, is_default, is_system FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail="模版不存在")
            if not _can_manage_prompt(current_user, existing):
                raise HTTPException(status_code=403, detail="无权删除此模版")
            if existing["is_default"]:
                raise HTTPException(status_code=400, detail="默认模版不允许删除,请先设置其他模版为默认")

            # Remove users' selections of this prompt first, in the same
            # transaction, so no config row is left referencing a missing
            # prompt. Deleting children first also satisfies a non-cascading
            # foreign key, if one exists.
            cursor.execute("DELETE FROM prompt_config WHERE prompt_id = %s", (prompt_id,))
            cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
            connection.commit()
            return create_api_response(code="200", message="删除成功")
        except HTTPException:
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
|