group-wbl/knowledge_base.py
2026-01-08 14:11:42 +08:00

523 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/Users/bzbb/Documents/work/1/test/.venv/bin/python3
import os
import uuid
import time
import chromadb
import pypdf
import docx2txt
import jieba
from chromadb.config import Settings
from typing import List, Dict, Any, Optional
from datetime import datetime
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sklearn.feature_extraction.text import TfidfVectorizer
from whoosh.index import create_in, open_dir
from whoosh.fields import Schema, TEXT, ID, STORED
from whoosh.qparser import MultifieldParser
from whoosh.analysis import Tokenizer, Token
import tempfile
import shutil
# Custom Chinese tokenizer for Whoosh, backed by jieba word segmentation.
class ChineseTokenizer(Tokenizer):
    """Whoosh tokenizer that segments text with jieba.

    Emits one token per non-whitespace jieba segment and fills in the token
    attributes that Whoosh requests through the ``positions``/``chars`` flags.
    """

    def __call__(self, text, positions=False, chars=False, keeporiginal=False,
                 removestops=True, start_pos=0, start_char=0, mode="",
                 **kwargs):
        """Yield a token for each jieba segment of ``text``.

        Args:
            text: The string to tokenize.
            positions: When True, set ``token.pos`` to the running token index
                (starting at ``start_pos``); phrase queries depend on these
                being consecutive token indices, not character offsets.
            chars: When True, set ``token.startchar``/``token.endchar`` to the
                segment's character offsets, shifted by ``start_char``.
            keeporiginal: When True, also store the segment in
                ``token.original``.
            removestops: Forwarded to ``Token``.
            mode: Forwarded to ``Token`` (Whoosh passes e.g. "index"/"query").
            **kwargs: Extra attributes forwarded to ``Token``.

        Yields:
            The same ``Token`` object, updated for each segment.
        """
        # Whoosh convention: a single Token instance is reused for every yield.
        token = Token(positions, chars, removestops=removestops, mode=mode,
                      **kwargs)
        position = start_pos
        search_from = 0
        for word in jieba.cut(text):
            offset = text.find(word, search_from)
            if offset < 0:
                # Defensive: fall back to the running offset if the segment
                # is not found verbatim.
                offset = search_from
            search_from = offset + len(word)
            if not word.strip():
                # Whitespace segments carry nothing searchable.
                continue
            token.text = word
            if keeporiginal:
                token.original = word
            token.boost = 1.0
            token.stopped = False
            if positions:
                token.pos = position
                position += 1
            if chars:
                token.startchar = start_char + offset
                token.endchar = token.startchar + len(word)
            yield token
def ChineseAnalyzer():
    """Return the Whoosh analyzer used for Chinese full-text fields.

    A bare ``ChineseTokenizer`` instance serves as the analyzer; no extra
    filters are chained onto it.
    """
    analyzer = ChineseTokenizer()
    return analyzer
class KnowledgeBase:
    """Core knowledge-base class managing documents and their indexes.

    Each document is split into chunks. Every chunk is stored in a ChromaDB
    collection (text + metadata, used for dense retrieval) and in a Whoosh
    index (used for sparse keyword retrieval). ``document_metadata`` is an
    in-memory cache of chunk metadata keyed by chunk id.
    """

    # Name and metadata of the Chroma collection holding all chunks.
    _COLLECTION_NAME = "documents"
    _COLLECTION_METADATA = {"description": "智能知识库文档集合"}
    # Per-chunk bookkeeping keys: a replacement chunk recomputes these rather
    # than inheriting them from the chunk it replaces.
    _CHUNK_KEYS = frozenset(
        {"id", "parent_file", "chunk_index", "total_chunks", "timestamp", "version"}
    )
    # Keys produced by parse_document; recomputed when a chunk is re-read
    # from a (possibly different) file.
    _PARSED_KEYS = frozenset(
        {"file_path", "file_type", "parsed_at", "num_pages", "title"}
    )

    def __init__(self, persist_directory: str = "./knowledge_base"):
        """Open (or create) the knowledge base stored under ``persist_directory``.

        Args:
            persist_directory: Directory for the ChromaDB data; the Whoosh
                index lives in its ``sparse_index`` subdirectory.
        """
        from whoosh.index import exists_in

        self.persist_directory = persist_directory
        # Persistent ChromaDB client and the chunk collection.
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self._open_collection()
        # Chunk metadata cache, rebuilt from what Chroma already persisted.
        self.document_metadata: Dict[str, Dict[str, Any]] = {}
        self.load_metadata()
        # Prefer paragraph, line and Chinese/ASCII sentence boundaries; the
        # trailing "" is the last-resort character split.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", "。", "！", "？", "；", "，", " ", ""],
        )
        # Sparse (keyword) index.
        self.sparse_index_dir = os.path.join(persist_directory, "sparse_index")
        self.schema = Schema(
            doc_id=ID(stored=True),
            content=TEXT(stored=True, analyzer=ChineseAnalyzer()),
            title=TEXT(stored=True),
            file_path=STORED,
            timestamp=STORED,
        )
        os.makedirs(self.sparse_index_dir, exist_ok=True)
        # Check for an actual index rather than just the directory: a
        # directory left without an index would make open_dir fail.
        if exists_in(self.sparse_index_dir):
            self.sparse_index = open_dir(self.sparse_index_dir)
        else:
            self.sparse_index = create_in(self.sparse_index_dir, self.schema)
        # Not used by the retrieval methods of this class; presumably kept
        # for external callers.
        self.tfidf_vectorizer = TfidfVectorizer(tokenizer=jieba.lcut, use_idf=True)

    def _open_collection(self):
        """Return the chunk collection, creating it if it does not exist."""
        return self.client.get_or_create_collection(
            name=self._COLLECTION_NAME,
            metadata=dict(self._COLLECTION_METADATA),
        )

    def load_metadata(self):
        """Rebuild ``document_metadata`` from the metadata stored in Chroma.

        Every chunk is added to the collection together with its metadata, so
        the collection is the durable copy. Reloading it lets search, delete
        and update recognise chunks written by an earlier process.
        """
        stored = self.collection.get(include=["metadatas"])
        ids = stored.get("ids") or []
        metadatas = stored.get("metadatas") or []
        self.document_metadata = {
            chunk_id: dict(meta or {}) for chunk_id, meta in zip(ids, metadatas)
        }

    def save_metadata(self):
        """Persist document metadata.

        Intentionally a no-op: chunk metadata is written to Chroma together
        with each chunk and reloaded by ``load_metadata``. The hook is kept
        so existing call sites keep working.
        """

    def parse_document(self, file_path: str) -> Dict[str, Any]:
        """Parse a PDF, Word or plain-text file.

        Args:
            file_path: Path of the file; its extension selects the parser.

        Returns:
            ``{"content": str, "metadata": dict}``. The metadata holds
            ``file_path``, ``file_type``, ``parsed_at`` and ``title``, plus
            ``num_pages`` for PDFs.

        Raises:
            Exception: Wraps any parsing error (including an unsupported
                extension); the original error is chained as ``__cause__``.
        """
        file_extension = os.path.splitext(file_path)[1].lower()
        default_title = os.path.basename(file_path)
        metadata = {
            "file_path": file_path,
            "file_type": file_extension,
            "parsed_at": datetime.now().isoformat(),
        }
        try:
            if file_extension == ".pdf":
                with open(file_path, "rb") as f:
                    reader = pypdf.PdfReader(f)
                    # Newline between pages so words at page breaks don't fuse.
                    content = "\n".join(
                        page.extract_text() or "" for page in reader.pages
                    )
                    metadata["num_pages"] = len(reader.pages)
                    info = reader.metadata
                    metadata["title"] = (
                        info.title if info and info.title else default_title
                    )
            elif file_extension in (".doc", ".docx"):
                # NOTE: docx2txt reads the zip-based .docx format; a legacy
                # binary .doc file fails here and surfaces as a parse error.
                content = docx2txt.process(file_path)
                metadata["title"] = default_title
            elif file_extension == ".txt":
                content = self._read_text_file(file_path)
                metadata["title"] = default_title
            else:
                raise ValueError(f"不支持的文件格式: {file_extension}")
        except Exception as e:
            raise Exception(f"文档解析失败: {str(e)}") from e
        return {"content": content, "metadata": metadata}

    @staticmethod
    def _read_text_file(file_path: str) -> str:
        """Read a text file as UTF-8 (stripping a BOM), falling back to GB18030."""
        try:
            with open(file_path, "r", encoding="utf-8-sig") as f:
                return f.read()
        except UnicodeDecodeError:
            # Chinese text is often saved as GBK/GB2312; GB18030 is a superset.
            with open(file_path, "r", encoding="gb18030") as f:
                return f.read()

    def process_document(self, file_path: str, metadata: Dict = None) -> List[Dict[str, Any]]:
        """Run the ETL pipeline for one file: parse it and split it into chunks.

        Embeddings are not computed here; Chroma computes them when the chunks
        are added to the collection.

        Args:
            file_path: Path of the file to process.
            metadata: Extra document-level metadata; overrides parsed values.

        Returns:
            A list of ``{"id", "content", "metadata"}`` chunk dicts.
        """
        parsed = self.parse_document(file_path)
        doc_metadata = parsed["metadata"]
        if metadata:
            doc_metadata.update(metadata)
        chunks = self.text_splitter.split_text(parsed["content"])
        timestamp = datetime.now().isoformat()
        processed_chunks = []
        for index, chunk in enumerate(chunks):
            chunk_id = str(uuid.uuid4())
            processed_chunks.append({
                "id": chunk_id,
                "content": chunk,
                "metadata": {
                    "id": chunk_id,
                    "parent_file": file_path,
                    "chunk_index": index,
                    "total_chunks": len(chunks),
                    "timestamp": timestamp,
                    "version": 1,
                    # Document-level metadata (including caller overrides) wins.
                    **doc_metadata,
                },
            })
        return processed_chunks

    def _split_direct_input(self, content: str, metadata: Optional[Dict]) -> List[Dict[str, Any]]:
        """Split raw text into chunk dicts tagged with ``parent_file="direct_input"``."""
        chunks = self.text_splitter.split_text(content)
        doc_metadata = {
            "timestamp": datetime.now().isoformat(),
            "version": 1,
            "file_type": "text",
            "title": "直接输入内容",
            # Caller metadata (including a custom title) overrides the defaults.
            **(metadata or {}),
        }
        processed_chunks = []
        for index, chunk in enumerate(chunks):
            chunk_id = str(uuid.uuid4())
            processed_chunks.append({
                "id": chunk_id,
                "content": chunk,
                "metadata": {
                    "id": chunk_id,
                    "parent_file": "direct_input",
                    "chunk_index": index,
                    "total_chunks": len(chunks),
                    **doc_metadata,
                },
            })
        return processed_chunks

    def _index_chunks(self, chunks: List[Dict[str, Any]]) -> List[str]:
        """Add chunk dicts to Chroma, Whoosh and the metadata cache in one batch.

        Returns:
            The ids of the indexed chunks, in order.
        """
        if not chunks:
            return []
        ids = [chunk["id"] for chunk in chunks]
        documents = [chunk["content"] for chunk in chunks]
        metadatas = [chunk["metadata"] for chunk in chunks]
        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
        # The writer context commits on success and cancels (releasing the
        # index lock) if an exception escapes.
        with self.sparse_index.writer() as writer:
            for chunk_id, text, meta in zip(ids, documents, metadatas):
                writer.add_document(
                    doc_id=chunk_id,
                    content=text,
                    title=meta.get("title", ""),
                    file_path=meta.get("file_path", ""),
                    timestamp=meta.get("timestamp", ""),
                )
        # Cache only after both stores accepted the chunks.
        for chunk_id, meta in zip(ids, metadatas):
            self.document_metadata[chunk_id] = meta
        return ids

    def add_document(self, content: str = None, file_path: str = None, metadata: Dict = None) -> List[str]:
        """Add a document to the knowledge base.

        Args:
            content: Document text (optional if ``file_path`` is given).
            file_path: Path of a file to ingest (optional if ``content`` is
                given); takes precedence over ``content``.
            metadata: Extra document metadata.

        Returns:
            The ids of the created chunks.

        Raises:
            ValueError: If neither ``content`` nor ``file_path`` is given.
        """
        if not content and not file_path:
            raise ValueError("必须提供content或file_path")
        if file_path:
            processed_chunks = self.process_document(file_path, metadata)
        else:
            processed_chunks = self._split_direct_input(content, metadata)
        chunk_ids = self._index_chunks(processed_chunks)
        self.save_metadata()
        return chunk_ids

    def _stored_chunk_text(self, chunk_id: str) -> Optional[str]:
        """Return the text Chroma stores for ``chunk_id``, or None if absent."""
        stored = self.collection.get(ids=[chunk_id], include=["documents"])
        documents = stored.get("documents") or []
        return documents[0] if documents else None

    def update_document(self, document_id: str, content: str = None, file_path: str = None, metadata: Dict = None) -> bool:
        """Replace a chunk with new content, a new file, or new metadata.

        The replacement is split into chunks. The first reuses ``document_id``,
        the rest get fresh ids, and all get ``version`` bumped by one. Custom
        metadata of the old chunk is inherited; bookkeeping keys are recomputed.
        If only ``metadata`` is given, the chunk's stored text is re-indexed
        with the merged metadata.

        Args:
            document_id: Id of the chunk to replace.
            content: New text (takes precedence over ``file_path``).
            file_path: New file to parse.
            metadata: Metadata overriding the inherited values.

        Returns:
            True if the chunk was replaced; False if it is unknown (or, for a
            metadata-only update, its text cannot be found).
        """
        current_metadata = self.document_metadata.get(document_id)
        if current_metadata is None:
            return False
        if not content and not file_path:
            content = self._stored_chunk_text(document_id)
            if not content:
                return False
        new_version = current_metadata.get("version", 1) + 1
        inherited = {
            key: value for key, value in current_metadata.items()
            if key not in self._CHUNK_KEYS
        }

        if content:
            texts = self.text_splitter.split_text(content)
            timestamp = datetime.now().isoformat()
            parent_file = current_metadata.get("parent_file", "direct_input")
            overrides = {**inherited, **(metadata or {})}
            new_chunks = []
            for index, text in enumerate(texts):
                chunk_id = document_id if index == 0 else str(uuid.uuid4())
                new_chunks.append({
                    "id": chunk_id,
                    "content": text,
                    # Bookkeeping last, so stale values can never win.
                    "metadata": {
                        **overrides,
                        "id": chunk_id,
                        "parent_file": parent_file,
                        "chunk_index": index,
                        "total_chunks": len(texts),
                        "timestamp": timestamp,
                        "version": new_version,
                    },
                })
        else:
            # Values parsed from the old file must not mask the new file's.
            carried = {
                key: value for key, value in inherited.items()
                if key not in self._PARSED_KEYS
            }
            new_chunks = self.process_document(
                file_path, {**carried, **(metadata or {})}
            )
            for index, chunk in enumerate(new_chunks):
                if index == 0:
                    chunk["id"] = document_id
                chunk["metadata"]["id"] = chunk["id"]
                chunk["metadata"]["version"] = new_version

        # Touch the stores only once the replacement is ready, so a parse
        # failure above leaves the original chunk intact.
        self.delete_document(document_id)
        self._index_chunks(new_chunks)
        self.save_metadata()
        return True

    def delete_document(self, document_id: str) -> bool:
        """Delete a chunk from Chroma, Whoosh and the metadata cache.

        Args:
            document_id: Id of the chunk to delete.

        Returns:
            True if deleted; False if the id is unknown.
        """
        if document_id not in self.document_metadata:
            return False
        self.collection.delete(ids=[document_id])
        # Commits on success; cancels and releases the lock on error.
        with self.sparse_index.writer() as writer:
            writer.delete_by_term("doc_id", document_id)
        del self.document_metadata[document_id]
        self.save_metadata()
        return True

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Return the cached metadata of a chunk, or None if it is unknown."""
        return self.document_metadata.get(document_id)

    def list_documents(self) -> List[Dict[str, Any]]:
        """Return the cached metadata of every chunk."""
        return list(self.document_metadata.values())

    def search(self, query: str, n_results: int = 5, hybrid_weight: float = 0.5) -> List[Dict[str, Any]]:
        """Hybrid search combining dense vectors (Chroma) and keywords (Whoosh).

        Args:
            query: Search text.
            n_results: Maximum number of results to return.
            hybrid_weight: Weight of the dense score (0 = sparse only,
                1 = dense only).

        Returns:
            Up to ``n_results`` dicts with ``id``, ``content``, ``metadata``,
            ``dense_score``, ``sparse_score`` and ``hybrid_score``, sorted by
            ``hybrid_score`` descending. ``dense_score`` is
            1 / (1 + distance). ``sparse_score`` is the Whoosh score divided
            by this query's best Whoosh score, so it lies in [0, 1] rather
            than being unbounded.
        """
        if n_results <= 0:
            return []
        # Over-fetch from both retrievers to give the fused ranking candidates.
        candidate_count = n_results * 2
        merged: Dict[str, Dict[str, Any]] = {}

        # 1. Dense retrieval; never request more hits than the collection holds.
        stored_count = self.collection.count()
        if stored_count > 0:
            dense = self.collection.query(
                query_texts=[query],
                n_results=min(candidate_count, stored_count),
                include=["documents", "metadatas", "distances"],
            )
            for chunk_id, text, meta, distance in zip(
                dense["ids"][0],
                dense["documents"][0],
                dense["metadatas"][0],
                dense["distances"][0],
            ):
                merged[chunk_id] = {
                    "id": chunk_id,
                    "content": text,
                    "metadata": meta,
                    "dense_score": 1.0 / (1.0 + distance),
                    "sparse_score": 0.0,
                }

        # 2. Sparse retrieval; keep only hits for chunks still in the cache.
        sparse_hits = []
        with self.sparse_index.searcher() as searcher:
            parser = MultifieldParser(["content", "title"], schema=self.schema)
            for hit in searcher.search(parser.parse(query), limit=candidate_count):
                doc_id = hit["doc_id"]
                if doc_id in self.document_metadata:
                    sparse_hits.append((doc_id, hit["content"], hit.score))

        # Raw Whoosh scores are unbounded; scale by the best hit.
        best_sparse = max((score for _, _, score in sparse_hits), default=0.0)
        for doc_id, text, raw_score in sparse_hits:
            sparse_score = raw_score / best_sparse if best_sparse > 0 else 0.0
            if doc_id in merged:
                merged[doc_id]["sparse_score"] = sparse_score
            else:
                merged[doc_id] = {
                    "id": doc_id,
                    "content": text,
                    "metadata": self.document_metadata[doc_id],
                    "dense_score": 0.0,
                    "sparse_score": sparse_score,
                }

        # 3. Fuse and rank.
        def hybrid_score(entry):
            return (hybrid_weight * entry["dense_score"]
                    + (1 - hybrid_weight) * entry["sparse_score"])

        ranked = sorted(merged.values(), key=hybrid_score, reverse=True)[:n_results]
        return [
            {
                "id": entry["id"],
                "content": entry["content"],
                "metadata": entry["metadata"],
                "dense_score": entry["dense_score"],
                "sparse_score": entry["sparse_score"],
                "hybrid_score": hybrid_score(entry),
            }
            for entry in ranked
        ]

    def clear(self):
        """Remove every chunk from Chroma, the Whoosh index and the cache."""
        # Recent Chroma releases reject delete() without ids or a filter, so
        # drop and recreate the collection instead.
        self.client.delete_collection(name=self._COLLECTION_NAME)
        self.collection = self._open_collection()
        # create_in over the existing index directory starts a fresh, empty index.
        self.sparse_index = create_in(self.sparse_index_dir, self.schema)
        self.document_metadata = {}
        self.save_metadata()
# Process-wide knowledge base instance, created lazily by get_knowledge_base().
global_knowledge_base = None
def get_knowledge_base() -> KnowledgeBase:
    """Return the shared KnowledgeBase, constructing it on first use.

    The instance is cached in the module global ``global_knowledge_base``;
    the first call builds it with the default persist directory.
    """
    global global_knowledge_base
    if global_knowledge_base is not None:
        return global_knowledge_base
    global_knowledge_base = KnowledgeBase()
    return global_knowledge_base