|
|
@@ -1,6 +1,7 @@
|
|
|
from typing import Any
|
|
|
from attrs import frozen
|
|
|
from loguru import logger
|
|
|
+from numpy.typing import NDArray
|
|
|
|
|
|
from de_quoi_parle_le_monde.storage import Storage
|
|
|
|
|
|
@@ -22,6 +23,25 @@ def batched(iterable, n):
|
|
|
yield batch
|
|
|
|
|
|
|
|
|
+@frozen
|
|
|
+class EmbeddingsJob:
|
|
|
+ article_id: int
|
|
|
+ text: NDArray
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ async def create(storage: Storage):
|
|
|
+ all_snapshots = await storage.list_all_featured_article_snapshots()
|
|
|
+ all_embeds_ids = set(
|
|
|
+ await storage.list_all_embedded_featured_article_snapshot_ids()
|
|
|
+ )
|
|
|
+
|
|
|
+ all_snapshots_not_stored = (
|
|
|
+ s for s in all_snapshots if s["id"] not in all_embeds_ids
|
|
|
+ )
|
|
|
+
|
|
|
+ return [EmbeddingsJob(s["id"], s["title"]) for s in all_snapshots_not_stored]
|
|
|
+
|
|
|
+
|
|
|
@frozen
|
|
|
class EmbeddingsWorker:
|
|
|
storage: Storage
|
|
|
@@ -48,30 +68,17 @@ class EmbeddingsWorker:
|
|
|
for i, embed in embeddings_by_id.items():
|
|
|
await self.storage.add_embedding(i, embed)
|
|
|
|
|
|
+ async def run(self, jobs: list[EmbeddingsJob]):
|
|
|
+ batch_size = 64
|
|
|
+ for batch in batched(jobs, batch_size):
|
|
|
+ embeddings_by_id = self.compute_embeddings_for(
|
|
|
+ {j.article_id: j.text for j in batch}
|
|
|
+ )
|
|
|
+ await self.store_embeddings(embeddings_by_id)
|
|
|
+
|
|
|
@staticmethod
|
|
|
def create(storage, model_path):
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
model = SentenceTransformer(model_path)
|
|
|
return EmbeddingsWorker(storage, model)
|
|
|
-
|
|
|
-
|
|
|
-async def compute_embeddings(storage: Storage):
|
|
|
- worker = EmbeddingsWorker.create(storage, "dangvantuan/sentence-camembert-large")
|
|
|
- all_snapshots = await storage.list_all_featured_article_snapshots()
|
|
|
- all_embeds_ids = set(
|
|
|
- await storage.list_all_embedded_featured_article_snapshot_ids()
|
|
|
- )
|
|
|
-
|
|
|
- all_snapshots_not_stored = (
|
|
|
- s for s in all_snapshots if s["id"] not in all_embeds_ids
|
|
|
- )
|
|
|
-
|
|
|
- batch_size = 64
|
|
|
- for batch in batched(all_snapshots_not_stored, batch_size):
|
|
|
- embeddings_by_id = worker.compute_embeddings_for(
|
|
|
- {s["id"]: s["title"] for s in batch}
|
|
|
- )
|
|
|
- await worker.store_embeddings(embeddings_by_id)
|
|
|
-
|
|
|
- return worker
|