|
|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+from itertools import islice
|
|
|
from datetime import date, datetime, time, timedelta
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
@@ -283,6 +284,60 @@ class StoreWorker(QueueWorker):
|
|
|
return {"storage": self.storage}
|
|
|
|
|
|
|
|
|
def batched(iterable, n):
    """
    Yield successive tuples of up to ``n`` items taken from ``iterable``.

    Every tuple holds exactly ``n`` items except possibly the final one,
    which holds whatever is left over:
    `batched('ABCDEFG', 3) --> ABC DEF G`

    Raises ``ValueError`` when ``n`` is smaller than one. Because this is a
    generator, the check only runs once iteration starts.

    Adapted from the recipes section of the itertools documentation:
    https://docs.python.org/3.11/library/itertools.html#itertools-recipes
    """
    if not n >= 1:
        raise ValueError("n must be at least one")

    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            # The underlying iterator is exhausted.
            return
        yield chunk
|
|
|
+
|
|
|
+
|
|
|
@frozen
class EmbeddingsWorker(Worker):
    """
    Compute and store sentence embeddings for titles that have none yet.

    `run()` loads the sentence-transformers model named by `model_name`,
    asks `storage` for every title lacking an embedding, encodes the titles
    in batches of `batch_size`, and stores each resulting vector through
    `storage.add_embedding`.
    """

    # Persistence layer used to list the pending titles and save embeddings.
    storage: Storage
    # Model identifier handed to `SentenceTransformer(...)`.
    model_name: str
    # Maximum number of titles encoded per `model.encode` call.
    batch_size: int

    async def run(self):
        """
        Embed all pending titles once, then return.

        Both the model loading and the encoding are blocking calls, so they
        are pushed to the loop's default executor to keep the other tasks
        sharing this event loop responsive.
        """

        def load_model():
            logger.info("Starting stuff")
            # Imported here rather than at module level so the dependency is
            # only loaded when this worker actually runs.
            from sentence_transformers import SentenceTransformer

            return SentenceTransformer(self.model_name)

        def compute_embeddings_for(model, sentences: tuple[tuple[int, str], ...]):
            """Return a dict mapping each title id to its embedding."""
            logger.debug(f"Computing embeddings for {len(sentences)} sentences")
            ids = [title_id for title_id, _ in sentences]
            all_texts = [text for _, text in sentences]
            all_embeddings = model.encode(all_texts)

            # Assumes `encode` returns one vector per input text, in input
            # order, so ids and vectors can be paired positionally.
            return dict(zip(ids, all_embeddings))

        logger.info("Embeddings watchdog")

        loop = asyncio.get_running_loop()
        model = await loop.run_in_executor(None, load_model)
        logger.info("Model loaded")
        all_titles = [
            (t["id"], t["text"])
            for t in await self.storage.list_all_titles_without_embedding()
        ]

        for batch in batched(all_titles, self.batch_size):
            # Encoding is blocking work: calling it directly here would
            # freeze the event loop for the whole batch.
            embeddings = await loop.run_in_executor(
                None, compute_embeddings_for, model, batch
            )
            # NOTE(review): this logs every full vector at info level;
            # consider lowering the level or logging only a summary.
            logger.info(f"embeddings: {embeddings}")
            for title_id, embed in embeddings.items():
                await self.storage.add_embedding(title_id, embed)

            logger.debug(f"Stored {len(embeddings)} embeddings")
|
|
|
+
|
|
|
+
|
|
|
@frozen
|
|
|
class WebServer(Worker):
|
|
|
async def run(self):
|
|
|
@@ -317,6 +372,11 @@ async def main():
|
|
|
"store": StoreWorker(SnapshotStoreJob.queue, None, storage),
|
|
|
}
|
|
|
web_server = WebServer()
|
|
|
+ embeds = EmbeddingsWorker(
|
|
|
+ storage,
|
|
|
+ "dangvantuan/sentence-camembert-large",
|
|
|
+ 64,
|
|
|
+ )
|
|
|
async with asyncio.TaskGroup() as tg:
|
|
|
for w in workers.values():
|
|
|
tasks.append(tg.create_task(w.loop()))
|
|
|
@@ -324,6 +384,7 @@ async def main():
|
|
|
await SnapshotSearchJob.queue.put(j)
|
|
|
|
|
|
tasks.append(tg.create_task(web_server.run()))
|
|
|
+ tasks.append(tg.create_task(embeds.run()))
|
|
|
finally:
|
|
|
await storage.close()
|
|
|
|