# imetting/backend/app/api/endpoints/prompts.py
#
# 384 lines
# 15 KiB
# Python

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from typing import Optional, List
from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
from app.models.models import PromptCreate, PromptUpdate
# Router that the prompt-template endpoints below register on via @router.<method>(...).
router = APIRouter()
class PromptConfigItem(BaseModel):
    """One entry of a user's per-task prompt selection, as sent to PUT /prompts/config/{task_type}."""

    # ID of the row in the `prompts` table this entry refers to.
    prompt_id: int
    # Only entries with is_enabled=True are persisted by update_prompt_config.
    is_enabled: bool = True
    # Requested position; the endpoint sorts by (sort_order, prompt_id) and stores renumbered positions.
    sort_order: int = 0
class PromptConfigUpdateRequest(BaseModel):
    """Request body for PUT /prompts/config/{task_type}."""

    # Complete desired selection; it replaces the user's existing config rows for the task type.
    items: List[PromptConfigItem]
def _is_admin(user: dict) -> bool:
return int(user.get("role_id") or 0) == 1
def _can_manage_prompt(current_user: dict, row: dict) -> bool:
    """Decide whether *current_user* may modify or delete the prompt *row*.

    Administrators may manage every prompt. Any other user may only manage
    personal (``is_system`` falsy) prompts whose ``creator_id`` is their own
    ``user_id``.
    """
    if _is_admin(current_user):
        return True
    owner_id = int(row.get("creator_id") or 0)
    if owner_id != int(current_user["user_id"]):
        return False
    return int(row.get("is_system") or 0) == 0
@router.post("/prompts")
def create_prompt(
    prompt: PromptCreate,
    current_user: dict = Depends(get_current_user),
):
    """Create a prompt template owned by the current user.

    Only administrators can create system prompts; for any other caller a
    requested ``is_system`` flag is ignored and a personal prompt is created.

    Defaults are tracked per (task_type, is_system, creator) group. The first
    prompt of a group always becomes its default; later ones only when
    ``prompt.is_default`` is set. A new default first clears the flag on the
    group's other prompts.

    Returns the new row ID as ``data.id``. Any error rolls the transaction
    back and is reported as HTTP 500 with the error text as ``detail``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            admin_caller = _is_admin(current_user)
            wants_system = bool(getattr(prompt, "is_system", False))
            system_flag = 1 if admin_caller and wants_system else 0
            creator_id = current_user["user_id"]
            # The group within which at most one prompt should be the default.
            group = (prompt.task_type, system_flag, creator_id)

            cursor.execute(
                """
                SELECT COUNT(*) as cnt
                FROM prompts
                WHERE task_type = %s
                AND is_system = %s
                AND creator_id = %s
                """,
                group,
            )
            siblings = (cursor.fetchone() or {}).get("cnt", 0)

            # The first prompt of a group is forced to be the default.
            if siblings == 0 or prompt.is_default:
                default_flag = 1
            else:
                default_flag = 0

            if default_flag:
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s
                    AND is_system = %s
                    AND creator_id = %s
                    """,
                    group,
                )

            cursor.execute(
                """
                INSERT INTO prompts (name, task_type, content, `desc`, is_default, is_active, creator_id, is_system)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    prompt.name,
                    prompt.task_type,
                    prompt.content,
                    prompt.desc,
                    default_flag,
                    1 if prompt.is_active else 0,
                    creator_id,
                    system_flag,
                ),
            )
            new_id = cursor.lastrowid
            connection.commit()
            return create_api_response(code="200", message="提示词模版创建成功", data={"id": new_id})
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
def _escape_like(value: str, escape_char: str = "!") -> str:
    """Return *value* with SQL LIKE wildcards escaped so it matches literally.

    Meant for ``LIKE %s ESCAPE '!'``. The escape character itself is doubled
    first, then ``%`` and ``_`` are each prefixed with it.
    """
    return (
        value.replace(escape_char, escape_char + escape_char)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 12,
    keyword: Optional[str] = Query(None),
    is_active: Optional[int] = Query(None),
    scope: str = Query("mine"),  # mine / system / all / accessible
    current_user: dict = Depends(get_current_user),
):
    """Return one page of prompt cards visible to the current user.

    ``scope`` selects the base set:

    - ``mine`` (default, also used for any unrecognised value): the caller's
      personal (non-system) prompts.
    - ``system``: all system prompts.
    - ``all``: for admins, system prompts plus prompts the admin created;
      non-admins are downgraded to ``accessible``.
    - ``accessible``: active system prompts plus the caller's personal prompts.

    ``task_type``, ``keyword`` and ``is_active`` further narrow the set.
    ``keyword`` is matched literally (``%``/``_`` are not wildcards) against
    name and description. ``is_active`` must be 0 or 1; other values are
    ignored. A ``page`` below 1 yields the first page. A negative ``size`` is
    treated as 0 instead of producing an invalid negative ``LIMIT``.

    Returns ``{"prompts": rows, "total": n, "page": page, "size": size}``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        is_admin = _is_admin(current_user)
        where_conditions = []
        params = []
        if scope == "all" and not is_admin:
            scope = "accessible"
        if scope == "system":
            where_conditions.append("p.is_system = 1")
        elif scope == "all":
            where_conditions.append("(p.is_system = 1 OR p.creator_id = %s)")
            params.append(current_user["user_id"])
        elif scope == "accessible":
            where_conditions.append("((p.is_system = 1 AND p.is_active = 1) OR (p.is_system = 0 AND p.creator_id = %s))")
            params.append(current_user["user_id"])
        else:
            where_conditions.append("p.is_system = 0 AND p.creator_id = %s")
            params.append(current_user["user_id"])
        if task_type:
            where_conditions.append("p.task_type = %s")
            params.append(task_type)
        if keyword:
            # Escape LIKE metacharacters so user input is a literal substring search.
            where_conditions.append("(p.name LIKE %s ESCAPE '!' OR p.`desc` LIKE %s ESCAPE '!')")
            like = f"%{_escape_like(keyword)}%"
            params.extend([like, like])
        if is_active in (0, 1):
            where_conditions.append("p.is_active = %s")
            params.append(is_active)
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) as total FROM prompts p WHERE {where_clause}", tuple(params))
        total = (cursor.fetchone() or {}).get("total", 0)

        # MySQL rejects a negative LIMIT; an out-of-range size yields an empty page instead.
        size = max(size, 0)
        offset = max(page - 1, 0) * size
        cursor.execute(
            f"""
            SELECT p.id, p.name, p.task_type, p.content, p.`desc`, p.is_default, p.is_active,
            p.creator_id, p.is_system, p.created_at,
            u.caption AS creator_name
            FROM prompts p
            LEFT JOIN sys_users u ON u.user_id = p.creator_id
            WHERE {where_clause}
            ORDER BY p.is_system DESC, p.task_type ASC, p.is_default DESC, p.created_at DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params + [size, offset]),
        )
        rows = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": rows, "total": total, "page": page, "size": size},
        )
@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """List active prompts of *task_type* the user can pick from.

    Candidates are active system prompts plus active prompts the user created.
    Each row is joined with the user's ``prompt_config`` entry for this task
    type, if any. Rows are ordered enabled-first, then by configured
    ``sort_order``, default flag and newest creation time.

    If the user has at least one enabled config entry, only those enabled
    prompts are returned; otherwise every candidate is returned.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        user_id = current_user["user_id"]
        cursor.execute(
            """
            SELECT p.id, p.name, p.`desc`, p.content, p.is_default, p.is_system, p.creator_id,
            cfg.is_enabled, cfg.sort_order
            FROM prompts p
            LEFT JOIN prompt_config cfg
            ON cfg.prompt_id = p.id
            AND cfg.user_id = %s
            AND cfg.task_type = %s
            WHERE p.task_type = %s
            AND p.is_active = 1
            AND (p.is_system = 1 OR p.creator_id = %s)
            ORDER BY
            CASE WHEN cfg.is_enabled = 1 THEN 0 ELSE 1 END,
            cfg.sort_order ASC,
            p.is_default DESC,
            p.created_at DESC
            """,
            (user_id, task_type, task_type, user_id),
        )
        candidates = cursor.fetchall()
        configured = [row for row in candidates if row.get("is_enabled") == 1]
        # Fall back to the full candidate list when nothing is explicitly enabled.
        return create_api_response(
            code="200",
            message="获取启用模版列表成功",
            data={"prompts": configured or candidates},
        )
@router.get("/prompts/config/{task_type}")
def get_prompt_config(task_type: str, current_user: dict = Depends(get_current_user)):
    """Return the prompts the user may select for *task_type* and their saved selection.

    - ``available_prompts``: active system prompts plus active prompts created
      by the user, ordered system-first, default-first, newest-first.
    - ``configs``: the user's raw ``prompt_config`` rows for this task type,
      ordered by ``sort_order`` then ``config_id``.
    - ``selected_prompt_ids``: IDs of enabled config rows, in config order,
      restricted to prompts present in ``available_prompts``.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT id, name, task_type, content, `desc`, is_default, is_active, is_system, creator_id, created_at
            FROM prompts
            WHERE task_type = %s
            AND is_active = 1
            AND (is_system = 1 OR creator_id = %s)
            ORDER BY is_system DESC, is_default DESC, created_at DESC
            """,
            (task_type, current_user["user_id"]),
        )
        available = cursor.fetchall()
        cursor.execute(
            """
            SELECT prompt_id, is_enabled, sort_order
            FROM prompt_config
            WHERE user_id = %s AND task_type = %s
            ORDER BY sort_order ASC, config_id ASC
            """,
            (current_user["user_id"], task_type),
        )
        configs = cursor.fetchall()
        # A config row can outlive its prompt's availability (e.g. the prompt was
        # deactivated). Echoing such an ID as "selected" would make a round-trip
        # save fail the PUT endpoint's validation, so only report usable IDs.
        available_ids = {row["id"] for row in available}
        selected_prompt_ids = [
            item["prompt_id"]
            for item in configs
            if item.get("is_enabled") == 1 and item["prompt_id"] in available_ids
        ]
        return create_api_response(
            code="200",
            message="获取提示词配置成功",
            data={
                "task_type": task_type,
                "available_prompts": available,
                "configs": configs,
                "selected_prompt_ids": selected_prompt_ids,
            },
        )
@router.put("/prompts/config/{task_type}")
def update_prompt_config(
    task_type: str,
    request: PromptConfigUpdateRequest,
    current_user: dict = Depends(get_current_user),
):
    """Replace the user's prompt selection for *task_type*.

    Only enabled items are stored. They are ordered by
    ``(sort_order, prompt_id)`` and saved with positions 1..n. A prompt ID
    that appears more than once is stored once, at its first position in
    that order.

    Raises HTTP 400 when an enabled ID is not an active prompt of this task
    type that the user may use (a system prompt or one they created). Any
    failure rolls the transaction back; unexpected errors become HTTP 500.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            user_id = current_user["user_id"]
            enabled_items = [item for item in request.items if item.is_enabled]
            # Request order, each ID once (dict preserves insertion order).
            requested_ids = list(dict.fromkeys(int(item.prompt_id) for item in enabled_items))
            if requested_ids:
                placeholders = ",".join(["%s"] * len(requested_ids))
                cursor.execute(
                    f"""
                    SELECT id
                    FROM prompts
                    WHERE id IN ({placeholders})
                    AND task_type = %s
                    AND is_active = 1
                    AND (is_system = 1 OR creator_id = %s)
                    """,
                    tuple(requested_ids + [task_type, user_id]),
                )
                valid_ids = {row["id"] for row in cursor.fetchall()}
                invalid_ids = [pid for pid in requested_ids if pid not in valid_ids]
                if invalid_ids:
                    raise HTTPException(status_code=400, detail=f"存在无效提示词ID: {invalid_ids}")

            cursor.execute(
                "DELETE FROM prompt_config WHERE user_id = %s AND task_type = %s",
                (user_id, task_type),
            )
            ordered = sorted(enabled_items, key=lambda x: (x.sort_order, x.prompt_id))
            # Collapse repeated IDs so the same prompt is not inserted twice.
            ordered_ids = list(dict.fromkeys(int(item.prompt_id) for item in ordered))
            for position, pid in enumerate(ordered_ids, start=1):
                cursor.execute(
                    """
                    INSERT INTO prompt_config (user_id, task_type, prompt_id, is_enabled, sort_order)
                    VALUES (%s, %s, %s, 1, %s)
                    """,
                    (user_id, task_type, pid, position),
                )
            connection.commit()
            return create_api_response(code="200", message="提示词配置保存成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e)) from e
@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update a prompt template with the fields set in the request.

    All checks run before anything is written:

    - 404 if the prompt does not exist.
    - 403 unless the caller may manage it (admin, or creator of a personal prompt).
    - 400 when trying to unset ``is_default`` on the current default.
    - 403 when a non-admin sends ``is_system``.
    - 400 when the prompt is, or is becoming, the default and the request
      sets ``is_active`` to false (a default template must stay enabled).

    When the prompt becomes the default, the flag is first cleared on the
    other prompts of its (task_type, is_system, creator) group. The group is
    determined by the task type and system flag the prompt will have after
    this update. Any failure rolls the transaction back; unexpected errors
    become HTTP 500.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute("SELECT id, creator_id, task_type, is_default, is_system FROM prompts WHERE id = %s", (prompt_id,))
            existing = cursor.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail="模版不存在")
            if not _can_manage_prompt(current_user, existing):
                raise HTTPException(status_code=403, detail="无权修改此模版")
            if prompt.is_default is False and existing["is_default"]:
                raise HTTPException(status_code=400, detail="必须保留一个默认模版,请先设置其他模版为默认")
            if prompt.is_system is not None and not _is_admin(current_user):
                raise HTTPException(status_code=403, detail="普通用户不能修改系统提示词属性")
            # Validate before any write. The current default is covered as well;
            # otherwise it could be deactivated without ever sending is_default.
            stays_default = bool(prompt.is_default) or bool(existing["is_default"])
            if stays_default and prompt.is_active is False:
                raise HTTPException(status_code=400, detail="默认模版必须处于启用状态")

            if prompt.is_default:
                # Clear defaults in the group the prompt will belong to after this update.
                task_type = prompt.task_type or existing["task_type"]
                is_system = existing.get("is_system", 0) if prompt.is_system is None else int(prompt.is_system)
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s
                    AND is_system = %s
                    AND creator_id = %s
                    """,
                    (task_type, is_system, existing["creator_id"]),
                )

            # Column names are PromptUpdate field names; exclude_unset keeps only the
            # fields the client sent. NOTE(review): assumes PromptUpdate does not accept
            # arbitrary extra keys, since they would be interpolated as columns — confirm.
            prompt_data = prompt.dict(exclude_unset=True)
            if prompt_data:
                assignments = ", ".join(
                    "`desc` = %s" if field == "desc" else f"{field} = %s" for field in prompt_data
                )
                cursor.execute(
                    f"UPDATE prompts SET {assignments} WHERE id = %s",
                    (*prompt_data.values(), prompt_id),
                )
            connection.commit()
            return create_api_response(code="200", message="更新成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt template.

    Raises HTTP 404 when the prompt does not exist, 403 when the caller may
    not manage it, and 400 when it is a default template. Unexpected
    database errors are rolled back and reported as HTTP 500.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute("SELECT id, creator_id, is_default, is_system FROM prompts WHERE id = %s", (prompt_id,))
            target = cursor.fetchone()

            rejection = None
            if not target:
                rejection = HTTPException(status_code=404, detail="模版不存在")
            elif not _can_manage_prompt(current_user, target):
                rejection = HTTPException(status_code=403, detail="无权删除此模版")
            elif target["is_default"]:
                # The default of a group must be replaced before it can be removed.
                rejection = HTTPException(status_code=400, detail="默认模版不允许删除,请先设置其他模版为默认")
            if rejection is not None:
                raise rejection

            cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
            connection.commit()
            return create_api_response(code="200", message="删除成功")
        except HTTPException:
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))