from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel

from app.core.auth import get_current_user
from app.core.database import get_db_connection
from app.core.response import create_api_response
from app.models.models import PromptCreate, PromptUpdate

router = APIRouter()

# role_id value that grants administrator rights (checked by _is_admin).
_ADMIN_ROLE_ID = 1


class PromptConfigItem(BaseModel):
    """One entry of a user's per-task prompt selection."""

    prompt_id: int
    is_enabled: bool = True
    # Client-side ordering hint; stored rows are renumbered 1..n on save.
    sort_order: int = 0


class PromptConfigUpdateRequest(BaseModel):
    """Payload that replaces a user's prompt selection for one task type."""

    items: List[PromptConfigItem]


def _is_admin(user: dict) -> bool:
    """Return True if *user* has the administrator role.

    A missing or empty ``role_id`` counts as non-admin.
    """
    return int(user.get("role_id") or 0) == _ADMIN_ROLE_ID


def _can_manage_prompt(current_user: dict, row: dict) -> bool:
    """Return True if *current_user* may edit or delete the prompt *row*.

    Admins may manage any prompt; other users only their own non-system prompts.
    """
    if _is_admin(current_user):
        return True
    return (
        int(row.get("creator_id") or 0) == int(current_user["user_id"])
        and int(row.get("is_system") or 0) == 0
    )


@router.post("/prompts")
def create_prompt(
    prompt: PromptCreate,
    current_user: dict = Depends(get_current_user),
):
    """Create a prompt template owned by the caller.

    Admins may create system prompts. For everyone else ``is_system`` is forced
    to 0, so the prompt is personal. The first prompt in a
    (task_type, is_system, creator) scope always becomes the default. Whenever
    the new prompt is the default, every other prompt in that scope loses the
    flag.

    Raises:
        HTTPException 500: any database error (the transaction is rolled back).
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            is_admin = _is_admin(current_user)
            requested_is_system = bool(getattr(prompt, "is_system", False))
            is_system = 1 if (is_admin and requested_is_system) else 0
            owner_user_id = current_user["user_id"]

            cursor.execute(
                """
                SELECT COUNT(*) as cnt
                FROM prompts
                WHERE task_type = %s AND is_system = %s AND creator_id = %s
                """,
                (prompt.task_type, is_system, owner_user_id),
            )
            count = (cursor.fetchone() or {}).get("cnt", 0)
            # The first prompt in its scope must be the default so the scope always has one.
            is_default = 1 if count == 0 else (1 if prompt.is_default else 0)

            if is_default:
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s AND is_system = %s AND creator_id = %s
                    """,
                    (prompt.task_type, is_system, owner_user_id),
                )

            cursor.execute(
                """
                INSERT INTO prompts (name, task_type, content, `desc`, is_default, is_active, creator_id, is_system)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (
                    prompt.name,
                    prompt.task_type,
                    prompt.content,
                    prompt.desc,
                    is_default,
                    1 if prompt.is_active else 0,
                    owner_user_id,
                    is_system,
                ),
            )
            prompt_id = cursor.lastrowid
            connection.commit()
            return create_api_response(code="200", message="提示词模版创建成功", data={"id": prompt_id})
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))


@router.get("/prompts")
def get_prompts(
    task_type: Optional[str] = None,
    page: int = 1,
    size: int = 12,
    keyword: Optional[str] = Query(None),
    is_active: Optional[int] = Query(None),
    scope: str = Query("mine"),  # mine / system / all / accessible
    current_user: dict = Depends(get_current_user),
):
    """Return one page of prompt cards plus the total number of matches.

    ``scope`` selects the base set:
      * ``mine`` (default; also used for unrecognised values): the caller's
        personal (non-system) prompts.
      * ``system``: all system prompts.
      * ``all``: system prompts plus the caller's own prompts. Non-admins are
        downgraded to ``accessible``.
      * ``accessible``: active system prompts plus the caller's personal prompts.

    ``task_type``, ``keyword`` (LIKE match on name/desc) and ``is_active``
    (0 or 1; other values are ignored) narrow the set further. ``page`` and
    ``size`` are clamped to at least 1, so a bad value cannot produce a
    negative LIMIT/OFFSET.
    """
    page = max(page, 1)
    size = max(size, 1)
    user_id = current_user["user_id"]
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        is_admin = _is_admin(current_user)

        where_conditions = []
        params = []

        if scope == "all" and not is_admin:
            scope = "accessible"

        if scope == "system":
            where_conditions.append("p.is_system = 1")
        elif scope == "all":
            where_conditions.append("(p.is_system = 1 OR p.creator_id = %s)")
            params.append(user_id)
        elif scope == "accessible":
            where_conditions.append(
                "((p.is_system = 1 AND p.is_active = 1) OR (p.is_system = 0 AND p.creator_id = %s))"
            )
            params.append(user_id)
        else:
            where_conditions.append("p.is_system = 0 AND p.creator_id = %s")
            params.append(user_id)

        if task_type:
            where_conditions.append("p.task_type = %s")
            params.append(task_type)
        if keyword:
            where_conditions.append("(p.name LIKE %s OR p.`desc` LIKE %s)")
            like = f"%{keyword}%"
            params.extend([like, like])
        if is_active in (0, 1):
            where_conditions.append("p.is_active = %s")
            params.append(is_active)

        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) as total FROM prompts p WHERE {where_clause}", tuple(params))
        total = (cursor.fetchone() or {}).get("total", 0)

        offset = (page - 1) * size
        cursor.execute(
            f"""
            SELECT p.id, p.name, p.task_type, p.content, p.`desc`, p.is_default, p.is_active,
                   p.creator_id, p.is_system, p.created_at, u.caption AS creator_name
            FROM prompts p
            LEFT JOIN sys_users u ON u.user_id = p.creator_id
            WHERE {where_clause}
            ORDER BY p.is_system DESC, p.task_type ASC, p.is_default DESC, p.created_at DESC
            LIMIT %s OFFSET %s
            """,
            tuple(params + [size, offset]),
        )
        rows = cursor.fetchall()
        return create_api_response(
            code="200",
            message="获取提示词列表成功",
            data={"prompts": rows, "total": total, "page": page, "size": size},
        )


@router.get("/prompts/active/{task_type}")
def get_active_prompts(task_type: str, current_user: dict = Depends(get_current_user)):
    """Return the active prompts selectable for *task_type*.

    Candidates are active system prompts plus the caller's own active prompts.
    Prompts enabled in the caller's ``prompt_config`` come first, ordered by
    their configured ``sort_order``. If the caller has enabled at least one
    candidate, only the enabled ones are returned; otherwise every candidate
    is returned.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT
                p.id, p.name, p.`desc`, p.content, p.is_default, p.is_system, p.creator_id,
                cfg.is_enabled, cfg.sort_order
            FROM prompts p
            LEFT JOIN prompt_config cfg
                ON cfg.prompt_id = p.id
               AND cfg.user_id = %s
               AND cfg.task_type = %s
            WHERE p.task_type = %s
              AND p.is_active = 1
              AND (p.is_system = 1 OR p.creator_id = %s)
            ORDER BY
                CASE WHEN cfg.is_enabled = 1 THEN 0 ELSE 1 END,
                cfg.sort_order ASC,
                p.is_default DESC,
                p.created_at DESC
            """,
            (current_user["user_id"], task_type, task_type, current_user["user_id"]),
        )
        prompts = cursor.fetchall()

        enabled = [x for x in prompts if x.get("is_enabled") == 1]
        # Fall back to every candidate when the user has not enabled anything.
        result = enabled if enabled else prompts
        return create_api_response(code="200", message="获取启用模版列表成功", data={"prompts": result})


@router.get("/prompts/config/{task_type}")
def get_prompt_config(task_type: str, current_user: dict = Depends(get_current_user)):
    """Return the caller's prompt configuration for *task_type*.

    The response contains every active prompt the caller can select (system
    prompts and their own), the raw ``prompt_config`` rows ordered by
    ``sort_order`` then ``config_id``, and the enabled prompt ids in that order.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        cursor.execute(
            """
            SELECT id, name, task_type, content, `desc`, is_default, is_active, is_system, creator_id, created_at
            FROM prompts
            WHERE task_type = %s
              AND is_active = 1
              AND (is_system = 1 OR creator_id = %s)
            ORDER BY is_system DESC, is_default DESC, created_at DESC
            """,
            (task_type, current_user["user_id"]),
        )
        available = cursor.fetchall()

        cursor.execute(
            """
            SELECT prompt_id, is_enabled, sort_order
            FROM prompt_config
            WHERE user_id = %s AND task_type = %s
            ORDER BY sort_order ASC, config_id ASC
            """,
            (current_user["user_id"], task_type),
        )
        configs = cursor.fetchall()
        selected_prompt_ids = [item["prompt_id"] for item in configs if item.get("is_enabled") == 1]
        return create_api_response(
            code="200",
            message="获取提示词配置成功",
            data={
                "task_type": task_type,
                "available_prompts": available,
                "configs": configs,
                "selected_prompt_ids": selected_prompt_ids,
            },
        )


@router.put("/prompts/config/{task_type}")
def update_prompt_config(
    task_type: str,
    request: PromptConfigUpdateRequest,
    current_user: dict = Depends(get_current_user),
):
    """Replace the caller's enabled-prompt selection for *task_type*.

    Only enabled items are stored. Every enabled id must be an active prompt
    of this task type that is either a system prompt or owned by the caller;
    otherwise 400 is returned and nothing is changed. Stored rows are
    renumbered 1..n following (sort_order, prompt_id). An id that appears
    more than once in the payload is stored once, at its first position.

    Raises:
        HTTPException 400: one or more enabled prompt ids are not selectable.
        HTTPException 500: any database error (the transaction is rolled back).
    """
    user_id = current_user["user_id"]
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            enabled_items = [item for item in request.items if item.is_enabled]
            # De-duplicate while preserving request order, so repeated ids cannot
            # create duplicate prompt_config rows.
            requested_ids = list(dict.fromkeys(int(item.prompt_id) for item in enabled_items))
            if requested_ids:
                placeholders = ",".join(["%s"] * len(requested_ids))
                cursor.execute(
                    f"""
                    SELECT id FROM prompts
                    WHERE id IN ({placeholders})
                      AND task_type = %s
                      AND is_active = 1
                      AND (is_system = 1 OR creator_id = %s)
                    """,
                    tuple(requested_ids + [task_type, user_id]),
                )
                valid_ids = {row["id"] for row in cursor.fetchall()}
                invalid_ids = [pid for pid in requested_ids if pid not in valid_ids]
                if invalid_ids:
                    raise HTTPException(status_code=400, detail=f"存在无效提示词ID: {invalid_ids}")

            cursor.execute(
                "DELETE FROM prompt_config WHERE user_id = %s AND task_type = %s",
                (user_id, task_type),
            )

            ordered_items = sorted(enabled_items, key=lambda x: (x.sort_order, x.prompt_id))
            ordered_ids = list(dict.fromkeys(int(item.prompt_id) for item in ordered_items))
            for position, pid in enumerate(ordered_ids, start=1):
                cursor.execute(
                    """
                    INSERT INTO prompt_config (user_id, task_type, prompt_id, is_enabled, sort_order)
                    VALUES (%s, %s, %s, 1, %s)
                    """,
                    (user_id, task_type, pid, position),
                )
            connection.commit()
            return create_api_response(code="200", message="提示词配置保存成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))


@router.put("/prompts/{prompt_id}")
def update_prompt(prompt_id: int, prompt: PromptUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update a prompt template.

    Only fields the client actually sent (``exclude_unset``) are written.
    Every validation runs before the first write, so a rejected request never
    leaves a half-applied change (such as cleared sibling defaults) behind;
    the transaction is also rolled back on any error.

    Raises:
        HTTPException 404: the prompt does not exist.
        HTTPException 403: the caller may not manage the prompt, or a
            non-admin tries to change ``is_system``.
        HTTPException 400: the update would leave no default, or would make a
            default prompt inactive.
        HTTPException 500: any unexpected database error.
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute(
                "SELECT id, creator_id, task_type, is_default, is_system FROM prompts WHERE id = %s",
                (prompt_id,),
            )
            existing = cursor.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail="模版不存在")
            if not _can_manage_prompt(current_user, existing):
                raise HTTPException(status_code=403, detail="无权修改此模版")
            if prompt.is_default is False and existing["is_default"]:
                raise HTTPException(status_code=400, detail="必须保留一个默认模版,请先设置其他模版为默认")
            if prompt.is_system is not None and not _is_admin(current_user):
                raise HTTPException(status_code=403, detail="普通用户不能修改系统提示词属性")

            # After this update the prompt is a default if the request makes it one,
            # or if it already is one and the request leaves is_default untouched.
            will_be_default = (
                bool(prompt.is_default) if prompt.is_default is not None else bool(existing["is_default"])
            )
            if will_be_default and prompt.is_active is False:
                raise HTTPException(status_code=400, detail="默认模版必须处于启用状态")

            if prompt.is_default:
                # Clear the flag in the scope the prompt will belong to after the update.
                target_task_type = prompt.task_type or existing["task_type"]
                target_is_system = (
                    int(prompt.is_system) if prompt.is_system is not None else existing.get("is_system", 0)
                )
                cursor.execute(
                    """
                    UPDATE prompts
                    SET is_default = 0
                    WHERE task_type = %s AND is_system = %s AND creator_id = %s
                    """,
                    (target_task_type, target_is_system, existing["creator_id"]),
                )

            prompt_data = prompt.dict(exclude_unset=True)
            if prompt_data:
                # Column names come from PromptUpdate's declared fields, not from client
                # input; backticks quote reserved words such as `desc`.
                assignments = ", ".join(f"`{field}` = %s" for field in prompt_data)
                cursor.execute(
                    f"UPDATE prompts SET {assignments} WHERE id = %s",
                    (*prompt_data.values(), prompt_id),
                )
            connection.commit()
            return create_api_response(code="200", message="更新成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))
@router.delete("/prompts/{prompt_id}")
def delete_prompt(prompt_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a prompt template the caller is allowed to manage.

    Default templates cannot be deleted. Any ``prompt_config`` rows that
    reference the prompt are removed in the same transaction. Without this,
    the per-user prompt configuration would keep returning the deleted id.

    Raises:
        HTTPException 404: the prompt does not exist.
        HTTPException 403: the caller may not manage the prompt.
        HTTPException 400: the prompt is a default template.
        HTTPException 500: any database error (the transaction is rolled back).
    """
    with get_db_connection() as connection:
        cursor = connection.cursor(dictionary=True)
        try:
            cursor.execute(
                "SELECT id, creator_id, is_default, is_system FROM prompts WHERE id = %s",
                (prompt_id,),
            )
            existing = cursor.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail="模版不存在")
            if not _can_manage_prompt(current_user, existing):
                raise HTTPException(status_code=403, detail="无权删除此模版")
            if existing["is_default"]:
                raise HTTPException(status_code=400, detail="默认模版不允许删除,请先设置其他模版为默认")

            # Drop dangling per-user config rows first; prompt_config is read elsewhere
            # without a join on prompts, so stale rows would surface a deleted id.
            cursor.execute("DELETE FROM prompt_config WHERE prompt_id = %s", (prompt_id,))
            cursor.execute("DELETE FROM prompts WHERE id = %s", (prompt_id,))
            connection.commit()
            return create_api_response(code="200", message="删除成功")
        except HTTPException:
            connection.rollback()
            raise
        except Exception as e:
            connection.rollback()
            raise HTTPException(status_code=500, detail=str(e))