259 lines
10 KiB
Python
259 lines
10 KiB
Python
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
from openai import OpenAI as OpenAI_Client
|
||
|
|
from llama_index.core import (
|
||
|
|
Document,
|
||
|
|
VectorStoreIndex,
|
||
|
|
StorageContext,
|
||
|
|
load_index_from_storage,
|
||
|
|
)
|
||
|
|
from llama_index.core.embeddings import BaseEmbedding
|
||
|
|
from llama_index.core.settings import Settings
|
||
|
|
|
||
|
|
from config import config
|
||
|
|
|
||
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class CustomOpenAIEmbedding(BaseEmbedding):
    """LlamaIndex embedding adapter that calls an OpenAI-compatible embeddings API.

    Every embedding is fetched with one synchronous ``embeddings.create`` call
    on an ``openai.OpenAI`` client. ``api_base`` lets the client target a
    server other than the client's default endpoint.
    """

    def __init__(
        self,
        model: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """Create the adapter and its API client.

        Args:
            model: Embedding model name sent with every request; also passed
                to ``BaseEmbedding`` as ``model_name``.
            api_key: API key for the endpoint. When falsy, the placeholder
                ``"not-needed"`` is sent instead (presumably for servers that
                do not check keys -- confirm against deployment).
            api_base: Base URL for the client; ``None`` leaves the choice to
                the ``openai`` client.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        super().__init__(model_name=model, **kwargs)
        self._client = OpenAI_Client(
            api_key=api_key or "not-needed",
            base_url=api_base,
        )
        self._model = model

    @staticmethod
    async def _run_blocking(func, *args):
        """Run the blocking callable ``func(*args)`` in the loop's default executor.

        The OpenAI client used here is synchronous; calling it directly inside
        a coroutine would stall the event loop for the whole HTTP round trip
        and serialize concurrent embedding requests.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding`` (runs off the event loop)."""
        return await self._run_blocking(self._get_embedding, query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant of ``_get_text_embedding`` (runs off the event loop)."""
        return await self._run_blocking(self._get_embedding, text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a search query; identical to embedding document text here."""
        return self._get_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a single document/node text."""
        return self._get_embedding(text)

    def _get_embedding(self, text: str) -> List[float]:
        """Request the embedding of *text* and return the first result's vector."""
        resp = self._client.embeddings.create(
            model=self._model,
            input=text,
        )
        return resp.data[0].embedding
|
||
|
|
|
||
|
|
|
||
|
|
class MeetingVectorStore:
    """Persistent LlamaIndex vector index over structured meeting records.

    Each meeting is stored as one ``Document`` per populated section (see
    ``_FIELD_TYPES``) with id ``f"{meeting_id}_{field}"``, where ``meeting_id``
    is derived from the meeting's date and title (``_meeting_id``). The index is
    persisted to ``config.vector_store.persist_dir`` after each mutation.
    """

    # Section names; each one is the suffix of a per-field document id.
    _FIELD_TYPES = ["header", "summary", "action_items", "metrics", "decisions", "relations", "entities"]

    # Bookkeeping metadata kept on every node but hidden from the text that is
    # embedded / shown to an LLM. LlamaIndex includes all metadata in that text
    # by default, so a hash, a file path and a 500-char transcript excerpt would
    # otherwise be embedded into every section (making sections of one meeting
    # look alike) and would eat into the node parser's chunk-size budget.
    _HIDDEN_METADATA_KEYS = ("content_hash", "original_text_path", "original_text_excerpt", "meeting_id")

    def __init__(self):
        """Install the embedding model globally, then load or create the index."""
        # Settings is LlamaIndex's process-wide default, so this embed model is
        # also picked up by any other index built in the same process.
        Settings.embed_model = CustomOpenAIEmbedding(
            model=config.embedding.model,
            api_key=config.embedding.api_key or None,
            api_base=config.embedding.api_base or None,
        )
        self.persist_dir = config.vector_store.persist_dir
        self._index: Optional[VectorStoreIndex] = None
        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load the index persisted in ``persist_dir``; fall back to a new empty index."""
        if os.path.exists(os.path.join(self.persist_dir, "docstore.json")):
            try:
                storage_context = StorageContext.from_defaults(persist_dir=self.persist_dir)
                self._index = load_index_from_storage(storage_context)
                logger.info(f"从磁盘加载向量索引: {self.persist_dir}")
                return
            except Exception as e:
                # The next _save() overwrites these files with the new empty
                # index, so keep the full traceback for diagnosis.
                logger.warning(f"加载向量索引失败,将创建新索引: {e}", exc_info=True)
        self._index = VectorStoreIndex.from_documents([])
        logger.info("创建新的向量索引")

    def _save(self):
        """Persist the index's storage context to ``persist_dir`` (created if missing)."""
        if self._index is None:
            return
        os.makedirs(self.persist_dir, exist_ok=True)
        self._index.storage_context.persist(persist_dir=self.persist_dir)

    def _meeting_id(self, meeting_data: dict) -> str:
        """Return a deterministic id ``meeting_<12 hex>`` from the meeting's date and title.

        MD5 is used only as a short stable fingerprint, not for security.
        """
        title = meeting_data.get("title", "")
        date = meeting_data.get("date", "")
        digest = hashlib.md5(f"{date}_{title}".encode("utf-8")).hexdigest()
        return f"meeting_{digest[:12]}"

    @staticmethod
    def _join_participants(participants) -> str:
        """Join participants with ", "; ``None`` or an empty list gives ``""``."""
        return ", ".join(str(p) for p in participants or [])

    def find_meeting(self, title: str, date: str = "") -> Optional[dict]:
        """Return the metadata of an indexed meeting with this title (and date), if any.

        Candidates come from a top-3 semantic search on the title/date. A
        candidate matches only if its title is exactly *title* and, when *date*
        is given, its date is exactly *date* -- the same two values that
        determine ``_meeting_id``. Returns ``None`` if nothing matches or the
        lookup fails.
        """
        if self._index is None:
            return None
        query_text = f"会议标题: {title}"
        if date:
            query_text += f" 日期: {date}"
        try:
            for hit in self.query(query_text, top_k=3):
                meta = hit.get("metadata", {})
                if meta.get("title", "") != title:
                    continue
                # Without this, a recurring meeting with the same title (or any
                # other meeting on the same day) would be reported as a duplicate.
                if date and meta.get("date") != date:
                    continue
                return meta
            return None
        except Exception as e:
            logger.warning(f"会议查重查询失败: {e}")
            return None

    def find_similar_text(self, text: str, threshold: float = 0.92) -> Optional[dict]:
        """Return the first top-3 node whose similarity score is strictly above *threshold*.

        Returns ``{"metadata": ..., "score": ...}`` or ``None`` (also on error).
        The score scale depends on the underlying vector store's metric.
        """
        if self._index is None:
            return None
        try:
            retriever = self._index.as_retriever(similarity_top_k=3)
            for node in retriever.retrieve(text):
                if node.score is not None and node.score > threshold:
                    return {
                        "metadata": node.metadata,
                        "score": node.score,
                    }
            return None
        except Exception as e:
            logger.warning(f"文本相似度查重失败: {e}")
            return None

    def _delete_meeting_docs(self, meeting_id: str) -> None:
        """Remove every per-field document of *meeting_id* from the index (not persisted)."""
        for field in self._FIELD_TYPES:
            # delete_from_docstore=True: by default only the vectors are removed
            # and the nodes stay in the docstore, inflating storage and stats.
            self._index.delete_ref_doc(f"{meeting_id}_{field}", delete_from_docstore=True)

    def remove_meeting(self, meeting_id: str) -> bool:
        """Delete all documents of *meeting_id* and persist.

        Returns ``True`` on success, ``False`` if there is no index or deletion fails.
        """
        if self._index is None:
            return False
        try:
            self._delete_meeting_docs(meeting_id)
            self._save()
            logger.info(f"已从向量索引移除会议: {meeting_id}")
            return True
        except Exception as e:
            logger.warning(f"移除向量索引失败: {e}")
            return False

    def add_meeting(self, meeting_data: dict) -> bool:
        """Index a meeting, replacing anything previously stored under the same id.

        Args:
            meeting_data: Meeting dict. Read keys: ``title``, ``date``,
                ``participants``, ``summary``, ``action_items``, ``metrics``,
                ``decisions``, ``relations``, ``entities`` and the private
                ``_original_text_path``, ``_original_text``, ``_content_hash``.

        Returns:
            ``True`` on success, ``False`` if building or inserting failed.
        """
        try:
            meeting_id = self._meeting_id(meeting_data)
            original_text = meeting_data.get("_original_text", "")
            base_metadata = {
                "title": meeting_data.get("title", ""),
                "date": meeting_data.get("date", ""),
                "participants": self._join_participants(meeting_data.get("participants")),
                "type": "meeting",
                "content_hash": meeting_data.get("_content_hash", ""),
                "original_text_path": meeting_data.get("_original_text_path", ""),
                "original_text_excerpt": original_text[:500] if original_text else "",
                "meeting_id": meeting_id,
            }

            docs = self._build_field_docs(meeting_data, base_metadata, meeting_id)

            if self._index is not None:
                # Same date+title -> same ids. insert() creates fresh nodes, so
                # drop the old ones first; otherwise re-adding an updated meeting
                # leaves duplicate, stale nodes in the index.
                self._delete_meeting_docs(meeting_id)
                for doc in docs:
                    self._index.insert(doc)
                self._save()
            logger.info(f"会议 '{meeting_data.get('title')}' 已添加到向量索引 (id={meeting_id}, 字段数={len(docs)})")
            return True
        except Exception as e:
            logger.error(f"添加会议到向量索引失败: {e}")
            return False

    def _build_field_docs(self, data: dict, base: dict, meeting_id: str) -> List[Document]:
        """Render each populated section of *data* as its own ``Document``.

        The header document is always produced; the other sections only when
        present and non-empty. Each document carries *base* plus ``field``.
        """
        hidden = list(self._HIDDEN_METADATA_KEYS)

        def make(field: str, text: str) -> Document:
            # One construction site so every section gets the same id scheme
            # and the same metadata exclusions.
            return Document(
                text=text,
                metadata={**base, "field": field},
                doc_id=f"{meeting_id}_{field}",
                excluded_embed_metadata_keys=list(hidden),
                excluded_llm_metadata_keys=list(hidden),
            )

        docs = []

        header = f"# {data.get('title', '')}"
        if data.get("date"):
            header += f"\n日期: {data['date']}"
        if data.get("participants"):
            header += f"\n参会人: {self._join_participants(data['participants'])}"
        docs.append(make("header", header))

        if data.get("summary"):
            docs.append(make("summary", data["summary"]))

        if data.get("action_items"):
            lines = []
            for item in data["action_items"]:
                status = item.get('status', '待办')
                lines.append(f"- [{status}] {item.get('task', '')} (负责人: {item.get('assignee', '')}, 截止: {item.get('deadline', '')}, 优先级: {item.get('priority', '')})")
                history = item.get("_history", [])
                # Only show the evolution trail when there is more than one state.
                if len(history) > 1:
                    lines.append(" 演变: " + " → ".join(f"{h.get('date','')}({h.get('status','')})" for h in history))
            docs.append(make("action_items", "\n".join(lines)))

        if data.get("metrics"):
            lines = [f"- {m.get('metric_name', '')}: {m.get('value', '')} (目标: {m.get('target', '')}, 趋势: {m.get('trend', '')})" for m in data["metrics"]]
            docs.append(make("metrics", "\n".join(lines)))

        if data.get("decisions"):
            lines = [f"- {d.get('content', '')}" for d in data["decisions"]]
            docs.append(make("decisions", "\n".join(lines)))

        if data.get("relations"):
            lines = [f"- {r.get('subject', '')} --{r.get('predicate', '')}--> {r.get('object', '')}" for r in data["relations"]]
            docs.append(make("relations", "\n".join(lines)))

        if data.get("entities"):
            lines = [f"- [{e.get('entity_type', '')}] {e.get('name', '')}: {e.get('description', '')}" for e in data["entities"]]
            docs.append(make("entities", "\n".join(lines)))

        return docs

    def query(self, question: str, top_k: int = 5) -> List[dict]:
        """Retrieve the *top_k* most similar nodes as ``{"text", "score", "metadata"}`` dicts.

        Returns ``[]`` when there is no index or retrieval fails.
        """
        if self._index is None:
            return []
        try:
            retriever = self._index.as_retriever(similarity_top_k=top_k)
            return [
                {
                    "text": node.text,
                    "score": node.score,
                    "metadata": node.metadata,
                }
                for node in retriever.retrieve(question)
            ]
        except Exception as e:
            logger.error(f"查询向量索引失败: {e}")
            return []

    def query_as_context(self, question: str, top_k: int = 3) -> str:
        """Format ``query`` results as numbered ``[i] title (date)`` + text blocks for a prompt."""
        results = self.query(question, top_k=top_k)
        if not results:
            return ""
        parts = []
        for i, r in enumerate(results):
            metadata = r.get("metadata", {})
            parts.append(f"[{i+1}] {metadata.get('title', '未知会议')} ({metadata.get('date', '')})\n{r['text']}\n")
        return "\n".join(parts)

    def get_stats(self) -> dict:
        """Return ``{"doc_count", "node_count"}``: source documents and stored nodes.

        Both are 0 when there is no index or the docstore cannot be inspected.
        """
        if self._index is None:
            return {"doc_count": 0, "node_count": 0}
        try:
            docstore = self._index.docstore
            nodes = docstore.docs if hasattr(docstore, 'docs') else {}
            # Ref-doc info is keyed by source Document id; docs holds the nodes.
            ref_docs = docstore.get_all_ref_doc_info() or {}
            return {
                "doc_count": len(ref_docs),
                "node_count": len(nodes),
            }
        except Exception:
            return {"doc_count": 0, "node_count": 0}
|
||
|
|
|
||
|
|
|
||
|
|
# Shared instance created at import time: importing this module sets
# Settings.embed_model and loads (or creates) the on-disk vector index.
meeting_vector_store = MeetingVectorStore()
|