# meeting_memory/vector_store.py
# Snapshot metadata from the repository web view: Python, 259 lines, 10 KiB,
# last changed 2026-05-15 08:39:57 +00:00.
import asyncio
import hashlib
import json
import logging
import os
import re
from typing import List, Optional

from llama_index.core import (
    Document,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.settings import Settings
from openai import OpenAI as OpenAI_Client

from config import config
# Logger named after this module, so its output can be filtered per component.
logger = logging.getLogger(name=__name__)
class CustomOpenAIEmbedding(BaseEmbedding):
    """LlamaIndex embedding model backed by an OpenAI-compatible embeddings endpoint.

    Every query or text chunk is embedded with a single ``embeddings.create``
    request made through the ``openai`` client, using ``api_base`` (when set)
    as the endpoint's base URL.
    """

    def __init__(
        self,
        model: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """Create the embedding wrapper.

        Args:
            model: Embedding model name sent with each request; also passed to
                ``BaseEmbedding`` as ``model_name``.
            api_key: API key. When empty, the ``OPENAI_API_KEY`` environment
                variable is used. If that is also unset, the placeholder
                ``"not-needed"`` is sent, for servers that do not check keys.
            api_base: Base URL of the endpoint. ``None`` lets the ``openai``
                client choose its own default.
            **kwargs: Forwarded unchanged to ``BaseEmbedding``.
        """
        super().__init__(model_name=model, **kwargs)
        # Handing the client any non-None key stops it from reading
        # OPENAI_API_KEY itself, so consult the environment explicitly before
        # falling back to the placeholder.
        self._client = OpenAI_Client(
            api_key=api_key or os.environ.get("OPENAI_API_KEY") or "not-needed",
            base_url=api_base,
        )
        self._model = model

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding``."""
        return await self._aget_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant of ``_get_text_embedding``."""
        return await self._aget_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a search query (same request shape as for document text)."""
        return self._get_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text chunk."""
        return self._get_embedding(text)

    async def _aget_embedding(self, text: str) -> List[float]:
        """Run the blocking ``_get_embedding`` call in the loop's default executor.

        ``self._client`` is the synchronous client. Calling it directly from a
        coroutine would block the event loop for the whole HTTP round trip.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_embedding, text)

    def _get_embedding(self, text: str) -> List[float]:
        """Request and return the embedding vector for a single input string."""
        resp = self._client.embeddings.create(
            model=self._model,
            input=text,
        )
        return resp.data[0].embedding
class MeetingVectorStore:
    """Persistent LlamaIndex vector index of meeting records.

    Each meeting is stored as several small documents, one per populated
    field (header, summary, action items, ...). All of them share the same
    base metadata and the deterministic id prefix ``meeting_id``, which is
    derived from the meeting's date and title. The index is loaded from and
    persisted to ``config.vector_store.persist_dir``.
    """

    # Field suffixes used for per-field document ids ("<meeting_id>_<field>").
    # Must match the fields produced by _build_field_docs.
    _FIELD_TYPES = ["header", "summary", "action_items", "metrics", "decisions", "relations", "entities"]

    # Line renderers for the simple list-valued fields, in document order.
    _LIST_FIELD_RENDERERS = (
        (
            "metrics",
            lambda m: (
                f"- {m.get('metric_name', '')}: {m.get('value', '')} "
                f"(目标: {m.get('target', '')}, 趋势: {m.get('trend', '')})"
            ),
        ),
        ("decisions", lambda d: f"- {d.get('content', '')}"),
        (
            "relations",
            lambda r: f"- {r.get('subject', '')} --{r.get('predicate', '')}--> {r.get('object', '')}",
        ),
        (
            "entities",
            lambda e: f"- [{e.get('entity_type', '')}] {e.get('name', '')}: {e.get('description', '')}",
        ),
    )

    def __init__(self):
        """Install the configured embedding model globally, then load or create the index."""
        embed_model = CustomOpenAIEmbedding(
            model=config.embedding.model,
            api_key=config.embedding.api_key or None,
            api_base=config.embedding.api_base or None,
        )
        # Settings is LlamaIndex's global default. The index loaded or created
        # below picks up this embed model from it.
        Settings.embed_model = embed_model
        self.persist_dir = config.vector_store.persist_dir
        self._index: Optional[VectorStoreIndex] = None
        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load the persisted index if ``docstore.json`` exists; otherwise start an empty one."""
        if os.path.exists(os.path.join(self.persist_dir, "docstore.json")):
            try:
                storage_context = StorageContext.from_defaults(persist_dir=self.persist_dir)
                self._index = load_index_from_storage(storage_context)
                logger.info(f"从磁盘加载向量索引: {self.persist_dir}")
                return
            except Exception as e:
                # NOTE(review): the next _save() overwrites the unreadable files in
                # persist_dir with this new, empty index. Consider backing them up.
                logger.warning(f"加载向量索引失败,将创建新索引: {e}", exc_info=True)
        self._index = VectorStoreIndex.from_documents([])
        logger.info("创建新的向量索引")

    def _save(self):
        """Persist the index into ``persist_dir``, creating the directory if needed."""
        if self._index is None:
            return
        os.makedirs(self.persist_dir, exist_ok=True)
        self._index.storage_context.persist(persist_dir=self.persist_dir)

    def _meeting_id(self, meeting_data: dict) -> str:
        """Return the deterministic id ``meeting_<12 hex chars>`` of a meeting.

        The id is derived from the MD5 of ``"<date>_<title>"``, so records with
        the same date and title share an id. MD5 serves as an identifier here,
        not as a security measure.
        """
        title = meeting_data.get("title", "")
        date = meeting_data.get("date", "")
        raw = f"{date}_{title}"
        return f"meeting_{hashlib.md5(raw.encode('utf-8')).hexdigest()[:12]}"

    def find_meeting(self, title: str, date: str = "") -> Optional[dict]:
        """Look up an already-indexed meeting by title and, if given, date.

        Runs a semantic search (top 3) for a header-like query. It returns the
        metadata of the first hit whose title equals ``title`` and, when
        ``date`` is non-empty, whose date equals ``date``. This is the same
        (date, title) pair that determines ``meeting_id``.

        Returns:
            The matching document's metadata dict, or ``None`` when nothing
            matches, there is no index, or the search fails.
        """
        if self._index is None:
            return None
        query_text = f"会议标题: {title}"
        if date:
            query_text += f" 日期: {date}"
        try:
            results = self.query(query_text, top_k=3)
        except Exception as e:
            logger.warning(f"会议查重查询失败: {e}")
            return None
        for r in results:
            meta = r.get("metadata") or {}
            # Require the title AND (when given) the date to match. Matching on
            # either alone flags unrelated meetings held on the same day, or
            # recurring meetings with the same title, as duplicates.
            if meta.get("title", "") == title and (not date or meta.get("date") == date):
                return meta
        return None

    def find_similar_text(self, text: str, threshold: float = 0.92) -> Optional[dict]:
        """Return the first retrieved chunk whose similarity score exceeds ``threshold``.

        Only the top 3 retrieved chunks are considered, in retriever order.

        Returns:
            ``{"metadata": ..., "score": ...}`` for the match, or ``None`` when
            nothing scores above ``threshold``, there is no index, or retrieval
            fails.
        """
        if self._index is None:
            return None
        try:
            nodes = self._index.as_retriever(similarity_top_k=3).retrieve(text)
        except Exception as e:
            logger.warning(f"文本相似度查重失败: {e}")
            return None
        for node in nodes:
            if node.score is not None and node.score > threshold:
                return {
                    "metadata": node.metadata,
                    "score": node.score,
                }
        return None

    def _delete_meeting_docs(self, meeting_id: str) -> None:
        """Delete every per-field document of ``meeting_id`` from the index (does not persist).

        A delete is issued for every known field type, including fields this
        meeting may not have. Those deletes are expected to be no-ops.
        """
        for field in self._FIELD_TYPES:
            # delete_from_docstore=True also drops the stored nodes and text. The
            # default (False) removes only the vectors and leaves the docstore
            # entries, and therefore docstore.json, growing.
            self._index.delete_ref_doc(f"{meeting_id}_{field}", delete_from_docstore=True)

    def remove_meeting(self, meeting_id: str) -> bool:
        """Remove all documents of a meeting from the index and persist the result.

        Returns:
            ``True`` once the deletes and persist complete; ``False`` if there
            is no index or an error occurred (the error is logged).
        """
        if self._index is None:
            return False
        try:
            self._delete_meeting_docs(meeting_id)
            self._save()
        except Exception as e:
            logger.warning(f"移除向量索引失败: {e}")
            return False
        logger.info(f"已从向量索引移除会议: {meeting_id}")
        return True

    def add_meeting(self, meeting_data: dict) -> bool:
        """Index a meeting, replacing any documents already stored under its id.

        Args:
            meeting_data: Meeting record. Reads ``title``, ``date``,
                ``participants``, ``summary``, ``action_items``, ``metrics``,
                ``decisions``, ``relations``, ``entities``, and the private
                keys ``_content_hash``, ``_original_text`` and
                ``_original_text_path``.

        Returns:
            ``True`` when the documents were built, inserted and persisted;
            ``False`` on any error (the error is logged).
        """
        try:
            meeting_id = self._meeting_id(meeting_data)
            original_text = meeting_data.get("_original_text") or ""
            participants = meeting_data.get("participants") or []
            base_metadata = {
                "title": meeting_data.get("title", ""),
                "date": meeting_data.get("date", ""),
                "participants": ", ".join(str(p) for p in participants),
                "type": "meeting",
                "content_hash": meeting_data.get("_content_hash", ""),
                "original_text_path": meeting_data.get("_original_text_path", ""),
                "original_text_excerpt": original_text[:500],
                "meeting_id": meeting_id,
            }
            # NOTE(review): LlamaIndex includes Document metadata in the embedded
            # text by default, so every field document also embeds this 500-char
            # excerpt, hash and path. Consider excluded_embed_metadata_keys.
            # Doing so changes similarity scores (see find_similar_text's threshold).
            docs = self._build_field_docs(meeting_data, base_metadata, meeting_id)
            if self._index is not None:
                # Inserting under existing doc ids adds new nodes next to the old
                # ones. Drop the previous version first so re-adding does not
                # duplicate retrieval hits.
                self._delete_meeting_docs(meeting_id)
                for doc in docs:
                    self._index.insert(doc)
                self._save()
            logger.info(f"会议 '{meeting_data.get('title')}' 已添加到向量索引 (id={meeting_id}, 字段数={len(docs)})")
            return True
        except Exception as e:
            logger.error(f"添加会议到向量索引失败: {e}")
            return False

    @staticmethod
    def _format_header(data: dict) -> str:
        """Render the title line plus optional date and participant lines."""
        lines = [f"# {data.get('title', '')}"]
        if data.get("date"):
            lines.append(f"日期: {data['date']}")
        if data.get("participants"):
            names = ", ".join(str(p) for p in data["participants"])
            lines.append(f"参会人: {names}")
        return "\n".join(lines)

    @staticmethod
    def _format_action_items(items: list) -> str:
        """Render action items one per line, followed by a status-history line when an item has more than one history entry."""
        lines = []
        for item in items:
            status = item.get("status", "待办")
            lines.append(
                f"- [{status}] {item.get('task', '')} "
                f"(负责人: {item.get('assignee', '')}, 截止: {item.get('deadline', '')}, "
                f"优先级: {item.get('priority', '')})"
            )
            history = item.get("_history") or []
            if len(history) > 1:
                # Separate the history steps; joining with "" ran them together.
                steps = " → ".join(f"{h.get('date', '')}({h.get('status', '')})" for h in history)
                lines.append(" 演变: " + steps)
        return "\n".join(lines)

    def _build_field_docs(self, data: dict, base: dict, meeting_id: str) -> List[Document]:
        """Render each populated field of a meeting as its own ``Document``.

        The header document is always produced. The others are produced only
        when the corresponding key in ``data`` is truthy. Each document carries
        ``base`` metadata plus ``field``, and has id ``"<meeting_id>_<field>"``.
        """

        def make_doc(field: str, text: str) -> Document:
            # Per-field ids let _delete_meeting_docs address each field document.
            return Document(
                text=text,
                metadata={**base, "field": field},
                doc_id=f"{meeting_id}_{field}",
            )

        docs = [make_doc("header", self._format_header(data))]
        if data.get("summary"):
            docs.append(make_doc("summary", data["summary"]))
        if data.get("action_items"):
            docs.append(make_doc("action_items", self._format_action_items(data["action_items"])))
        for field, render in self._LIST_FIELD_RENDERERS:
            items = data.get(field)
            if items:
                docs.append(make_doc(field, "\n".join(render(item) for item in items)))
        return docs

    def query(self, question: str, top_k: int = 5) -> List[dict]:
        """Return the ``top_k`` most similar chunks as ``{"text", "score", "metadata"}`` dicts.

        Returns an empty list when there is no index or retrieval fails.
        """
        if self._index is None:
            return []
        try:
            nodes = self._index.as_retriever(similarity_top_k=top_k).retrieve(question)
            # Kept inside the try: NodeWithScore.text can raise for non-text nodes.
            return [
                {"text": node.text, "score": node.score, "metadata": node.metadata}
                for node in nodes
            ]
        except Exception as e:
            logger.error(f"查询向量索引失败: {e}")
            return []

    def query_as_context(self, question: str, top_k: int = 3) -> str:
        """Format the top ``top_k`` hits as numbered, titled plain-text snippets.

        Each entry reads ``[n] <title> (<date>)`` followed by the chunk text.
        Returns ``""`` when nothing is found.
        """
        results = self.query(question, top_k=top_k)
        if not results:
            return ""
        parts = []
        for i, r in enumerate(results, start=1):
            metadata = r.get("metadata") or {}
            parts.append(f"[{i}] {metadata.get('title', '未知会议')} ({metadata.get('date', '')})\n{r['text']}\n")
        return "\n".join(parts)

    def get_stats(self) -> dict:
        """Return index size counters.

        Returns:
            ``{"doc_count": ..., "node_count": ...}``. ``doc_count`` is the
            number of source documents (one per populated meeting field).
            ``node_count`` is the number of nodes in the docstore. Both are 0
            when there is no index or the docstore cannot be read.
        """
        if self._index is None:
            return {"doc_count": 0, "node_count": 0}
        try:
            docstore = self._index.docstore
            ref_docs = docstore.get_all_ref_doc_info() or {}
            nodes = getattr(docstore, "docs", None) or {}
            return {
                "doc_count": len(ref_docs),
                "node_count": len(nodes),
            }
        except Exception as e:
            logger.warning(f"获取向量索引统计失败: {e}")
            return {"doc_count": 0, "node_count": 0}
# Module-level instance, created at import time. Constructing it sets
# LlamaIndex's global Settings.embed_model and loads (or creates) the index
# from config.vector_store.persist_dir.
meeting_vector_store: MeetingVectorStore = MeetingVectorStore()