This repository was archived by the owner on May 15, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembed.py
More file actions
104 lines (90 loc) · 3.34 KB
/
embed.py
File metadata and controls
104 lines (90 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import chromadb
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb.utils import embedding_functions
from chromadb.config import Settings
from progress.bar import IncrementalBar
def count_strings_in_range(string_array, max_length, min_length=0):
    """Build a histogram of string lengths within an inclusive range.

    Args:
        string_array: Iterable of sized objects (typically strings).
        max_length: Largest length that gets its own bucket (inclusive).
        min_length: Smallest length that gets its own bucket (inclusive).
            Defaults to 0, which reproduces the original fixed lower bound.

    Returns:
        A list of ``max_length - min_length + 1`` counters where index ``i``
        holds the number of items whose length equals ``min_length + i``.
        Items outside the range are ignored. An empty list is returned when
        ``max_length < min_length``.
    """
    # Clamp so an empty range yields [] rather than relying on the
    # "negative repeat count gives empty list" behaviour of list * int.
    bucket_count = max(0, max_length - min_length + 1)
    count_array = [0] * bucket_count
    for string in string_array:
        length = len(string)
        if min_length <= length <= max_length:
            count_array[length - min_length] += 1
    return count_array
class Embedder:
    """Chunk a JSON dataset of texts and store the chunks in a Chroma collection.

    Typical use is ``Embedder(path, directory).preprocess(...).embed(...)``;
    ``preprocess`` returns ``self`` so the calls can be chained.
    """

    def __init__(self, dataset, persist_directory):
        """Create the embedder and its Chroma client.

        Args:
            dataset: Path of the JSON file read by ``preprocess``.
            persist_directory: Directory handed to the Chroma settings as
                the persistence location.
        """
        # Filled by preprocess() with (id_string, chunk_document) tuples.
        self.chunks = []
        self.dataset = dataset
        # NOTE(review): with the "duckdb+parquet" backend, confirm whether an
        # explicit client.persist() is required for the installed Chroma
        # version; this class never calls it.
        self.client = chromadb.Client(
            Settings(
                chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
            )
        )

    def preprocess(self, min_text_length, max_chunk_size, chunk_overlap):
        """Load the dataset, drop unusable entries and split the rest into chunks.

        Args:
            min_text_length: Entries whose ``"text"`` is shorter than this are
                dropped; chunks must be strictly longer than this to be kept.
            max_chunk_size: ``chunk_size`` passed to the text splitter.
            chunk_overlap: ``chunk_overlap`` passed to the text splitter.

        Returns:
            ``self``, with ``self.chunks`` replaced by the new chunk list.

        Raises:
            KeyError: If a retained entry has no ``"id"`` key.
        """
        # JSON files are UTF-8 by specification; do not depend on the
        # platform's default locale encoding.
        with open(self.dataset, encoding="utf-8") as r:
            docs_raw = list(json.load(r))

        # Keep only entries that have a non-empty "text" of sufficient length.
        docs_clean = [
            e
            for e in docs_raw
            if "text" in e and e["text"] and len(e["text"]) >= min_text_length
        ]
        print(f"\n* Filtered {len(docs_raw)} sources down to {len(docs_clean)}")

        cleaned_docs = [
            Document(
                page_content=doc["text"],
                metadata={
                    "id": doc["id"],
                    "title": doc.get("title", ""),
                },
            )
            for doc in docs_clean
        ]
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=max_chunk_size, chunk_overlap=chunk_overlap, length_function=len
        )
        # The chunk id is the chunk's position in the splitter output, taken
        # before filtering, so ids of kept chunks may have gaps.
        # NOTE(review): this filter is strict (>), while the document filter
        # above keeps texts of exactly min_text_length -- confirm intended.
        self.chunks = [
            (str(index), chunk)
            for index, chunk in enumerate(text_splitter.split_documents(cleaned_docs))
            if len(chunk.page_content) > min_text_length
        ]
        # plot_hist(
        #     [len(chunk[1].page_content) for chunk in self.chunks],
        #     bincount=64,
        #     binwidth=8,
        #     xlab="Length of chunk",
        #     showSummary=True,
        # )
        return self

    def embed(self, name, embedding_function):
        """Create a collection and add every preprocessed chunk to it.

        Args:
            name: Name of the collection to create.
            embedding_function: Embedding function passed to the collection.
        """
        collection = self.client.create_collection(
            name, embedding_function=embedding_function
        )
        with IncrementalBar(
            f"* Embedding {len(self.chunks)} chunks",
            suffix="%(percent).1f%% - %(elapsed)ds",
            max=len(self.chunks),
        ) as bar:
            # Chunks are added one at a time so the bar advances per chunk.
            for chunk_id, chunk in self.chunks:
                collection.add(
                    ids=chunk_id,
                    documents=chunk.page_content,
                    metadatas={
                        "id": chunk.metadata["id"],
                        "title": chunk.metadata["title"],
                    },
                )
                bar.next()
        print("* Finished embedding documents")
# Script driver: build the embedding function (Instructor model, run on CPU),
# then preprocess the local dataset and embed the resulting chunks.
embed_func = embedding_functions.InstructorEmbeddingFunction(
    device="cpu",
    model_name="hkunlp/instructor-large",
)
embedder = Embedder(
    dataset="local/knowledge.json",
    persist_directory="local/embeddings",
)
# preprocess() returns the embedder itself, so embed() can be chained on it.
embedder.preprocess(
    min_text_length=80,
    max_chunk_size=512,
    chunk_overlap=128,
).embed(
    name="general-max-size-512",
    embedding_function=embed_func,
)