nex_docus/backend/app/api/v1/search.py

270 lines
11 KiB
Python
Raw Normal View History

2025-12-29 12:53:50 +00:00
"""
文档搜索相关 API
"""
2026-01-23 07:00:03 +00:00
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
2025-12-29 12:53:50 +00:00
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
2026-01-23 07:00:03 +00:00
from typing import Optional, List
from pathlib import Path
import logging
2025-12-29 12:53:50 +00:00
from app.core.database import get_db
from app.core.deps import get_current_user
from app.models.user import User
from app.models.project import Project, ProjectMember
2026-01-23 07:00:03 +00:00
from app.services.search_service import search_service
2025-12-29 12:53:50 +00:00
from app.services.storage import storage_service
from app.schemas.response import success_response
router = APIRouter()
2026-01-23 07:00:03 +00:00
logger = logging.getLogger(__name__)
2025-12-29 12:53:50 +00:00
@router.get("/documents", response_model=dict)
async def search_documents(
    keyword: str = Query(..., min_length=1, description="搜索关键词"),
    project_id: Optional[int] = Query(None, description="限制在指定项目中搜索"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Search documents (hybrid mode).

    Combines three sources, in this order:
      A. database match on project name/description (global search only),
      B. Whoosh full-text search,
      C. file-name match on the file system, as a fallback for files that
         are not in the index.

    Args:
        keyword: Search term.
        project_id: When given, restrict the search to this project.
        current_user: Authenticated user; used for permission filtering.
        db: Async database session.

    Returns:
        ``success_response`` whose ``data`` is a list of at most 100 result
        dicts (``type`` is ``"project"`` or ``"file"``). On an unexpected
        internal error an empty list with an explanatory message is returned.

    Raises:
        HTTPException: 404 if ``project_id`` does not exist, 403 if the user
            may not access it.
    """
    try:
        if not keyword:
            return success_response(data=[])

        # 1. Determine the search scope. Ids stay in their native (integer)
        #    form so they can be bound against Project.id without relying on
        #    implicit string->int coercion in the database driver.
        allowed_ids = set()
        target_project = None
        if project_id:
            # Check that the requested project exists and is accessible.
            result = await db.execute(select(Project).where(Project.id == project_id))
            target_project = result.scalar_one_or_none()
            if not target_project:
                raise HTTPException(status_code=404, detail="项目不存在")
            if target_project.owner_id != current_user.id and target_project.is_public != 1:
                member_result = await db.execute(
                    select(ProjectMember).where(
                        ProjectMember.project_id == project_id,
                        ProjectMember.user_id == current_user.id
                    )
                )
                if not member_result.scalar_one_or_none():
                    raise HTTPException(status_code=403, detail="无权访问该项目")
            allowed_ids.add(project_id)
        else:
            # Projects owned by the user.
            owned_result = await db.execute(
                select(Project.id).where(Project.owner_id == current_user.id, Project.status == 1)
            )
            allowed_ids.update(owned_result.scalars().all())
            # Projects the user is a member of.
            member_result = await db.execute(
                select(ProjectMember.project_id)
                .join(Project, ProjectMember.project_id == Project.id)
                .where(
                    ProjectMember.user_id == current_user.id,
                    Project.status == 1
                )
            )
            allowed_ids.update(member_result.scalars().all())

        if not allowed_ids:
            return success_response(data=[])

        allowed_id_list = list(allowed_ids)
        # String form is only used to compare against ids coming back from
        # the search index, whose type is not guaranteed here.
        allowed_id_strs = {str(pid) for pid in allowed_ids}

        # 2. Run the searches.
        search_results = []

        # A. Project search in the database. Skipped when a project_id is
        #    given: that is a search for files inside one project.
        if not project_id:
            projects_query = select(Project).where(
                Project.id.in_(allowed_id_list),
                or_(
                    Project.name.ilike(f"%{keyword}%"),
                    Project.description.ilike(f"%{keyword}%")
                )
            )
            project_res = await db.execute(projects_query)
            for proj in project_res.scalars().all():
                search_results.append({
                    "type": "project",
                    "project_id": proj.id,
                    "project_name": proj.name,
                    "project_description": proj.description or "",
                    "match_type": "项目名称/描述",
                })

        # B. Whoosh full-text search (best effort).
        whoosh_results = []
        try:
            if project_id:
                whoosh_results = await search_service.search(keyword, str(project_id))
            else:
                hits = await search_service.search(keyword, limit=50)
                # Permission filter for the global search.
                whoosh_results = [
                    hit for hit in hits
                    if str(hit['project_id']) in allowed_id_strs
                ]
        except Exception as e:
            logger.warning("Whoosh search failed: %s", e)
            # Never fall through with unfiltered hits if filtering failed.
            whoosh_results = []

        # Resolve project names for the Whoosh hits.
        whoosh_project_ids = set()
        for hit in whoosh_results:
            raw_pid = hit.get('project_id')
            if not raw_pid:
                continue
            try:
                whoosh_project_ids.add(int(raw_pid))
            except (TypeError, ValueError):
                # Non-numeric id: leave it unresolved ("未知项目" below).
                continue

        project_name_map = {}
        if whoosh_project_ids:
            p_res = await db.execute(
                select(Project.id, Project.name).where(Project.id.in_(list(whoosh_project_ids)))
            )
            for pid, pname in p_res.all():
                project_name_map[str(pid)] = pname

        for hit in whoosh_results:
            search_results.append({
                "type": "file",
                "project_id": hit['project_id'],
                "project_name": project_name_map.get(str(hit['project_id']), "未知项目"),
                "file_path": hit['path'],
                "file_name": hit['title'],
                "highlights": hit.get('highlights'),
                "match_type": "全文检索"
            })

        # C. File-name search on the file system, so files that are not
        #    indexed can still be found by name.
        projects_to_scan = []
        if project_id:
            # Reuse the project already loaded by the permission check.
            projects_to_scan = [target_project]
        elif len(search_results) < 20:
            # Simple heuristic: only walk every accessible project when the
            # other sources returned few results.
            scan_res = await db.execute(select(Project).where(Project.id.in_(allowed_id_list)))
            projects_to_scan = scan_res.scalars().all()

        # Keys of file results already collected, for de-duplication.
        existing_paths = {
            f"{item['project_id']}:{item['file_path']}"
            for item in search_results
            if item.get('type') == 'file'
        }

        keyword_lower = keyword.lower()
        # NOTE(review): Path.rglob is synchronous, so this walk blocks the
        # event loop while it runs; consider offloading it to a thread.
        for project in projects_to_scan:
            try:
                project_path = storage_service.get_secure_path(project.storage_key)
                if not project_path.exists():
                    continue
                # All .md matches first, then .pdf matches.
                for pattern in ("*.md", "*.pdf"):
                    for file_path in project_path.rglob(pattern):
                        if "_assets" in file_path.parts:
                            continue
                        if keyword_lower not in file_path.name.lower():
                            continue
                        rel_path = str(file_path.relative_to(project_path))
                        unique_key = f"{project.id}:{rel_path}"
                        if unique_key in existing_paths:
                            continue
                        search_results.append({
                            "type": "file",
                            "project_id": project.id,
                            "project_name": project.name,
                            "file_path": rel_path,
                            "file_name": file_path.name,
                            "match_type": "文件名匹配"
                        })
                        existing_paths.add(unique_key)
            except Exception as e:
                # Best effort: one unreadable project must not fail the search.
                logger.warning("File name scan failed for project %s: %s", project.id, e)
                continue

        return success_response(data=search_results[:100])
    except HTTPException:
        # 404/403 from the permission check must reach the client unchanged.
        raise
    except Exception as e:
        logger.exception("Search API error: %s", e)
        return success_response(data=[], message="搜索服务暂时不可用")
2025-12-29 12:53:50 +00:00
2026-01-23 07:00:03 +00:00
async def rebuild_index_task(db: AsyncSession):
    """Background task: rebuild the search index from all active projects.

    Loads every project with ``status == 1``, reads each ``*.md`` file under
    the project's storage root (skipping anything inside an ``_assets``
    directory) and passes the collected documents to
    ``search_service.rebuild_index_sync`` in the default executor.

    Failures are logged and never raised: an unreadable file or project is
    skipped, and any other error aborts the rebuild.

    Args:
        db: Async database session used to load the project list.
            NOTE(review): when scheduled from a request handler this is the
            request-scoped session - confirm it is still open when the task
            actually runs.
    """
    import asyncio

    logger.info("Starting index rebuild...")
    try:
        # Load all active projects.
        result = await db.execute(select(Project).where(Project.status == 1))
        projects = result.scalars().all()

        documents = []
        for project in projects:
            try:
                project_root = storage_service.get_secure_path(project.storage_key)
                if not project_root.exists():
                    continue

                for file_path in project_root.rglob("*.md"):
                    if "_assets" in file_path.parts:
                        continue
                    try:
                        content = await storage_service.read_file(file_path)
                        relative_path = str(file_path.relative_to(project_root))
                    except Exception as e:
                        # Skip this file but leave a trace of what was missed.
                        logger.warning(
                            "Skipping unreadable file %s in project %s: %s",
                            file_path, project.id, e
                        )
                        continue
                    documents.append({
                        "project_id": project.id,
                        "path": relative_path,
                        "title": file_path.stem,
                        "content": content
                    })
            except Exception as e:
                logger.error(f"Error processing project {project.id}: {e}")
                continue

        # Write the index in a worker thread: rebuild_index_sync is a
        # synchronous call and would otherwise block the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, search_service.rebuild_index_sync, documents)
        logger.info(f"Index rebuild completed. Indexed {len(documents)} documents.")
    except Exception as e:
        logger.exception("Index rebuild failed: %s", e)
2025-12-29 12:53:50 +00:00
2026-01-23 07:00:03 +00:00
@router.post("/rebuild-index", response_model=dict)
async def rebuild_index(
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Schedule a rebuild of the search index (superusers only).

    The rebuild runs as a background task; this endpoint returns as soon as
    the task has been queued.

    Raises:
        HTTPException: 403 when the current user is not a superuser.
    """
    is_admin = current_user.is_superuser
    if is_admin:
        # NOTE(review): `db` is the request-scoped session - confirm it is
        # still usable when the background task executes.
        background_tasks.add_task(rebuild_index_task, db)
        return success_response(message="索引重建任务已启动")
    raise HTTPException(status_code=403, detail="权限不足")