ソースを参照

Split embeddings service into job/worker

jherve 1 年前
コミット
b126f66969

+ 21 - 1
src/de_quoi_parle_le_monde/main.py

@@ -10,6 +10,10 @@ from de_quoi_parle_le_monde.web import app
 from de_quoi_parle_le_monde.http import HttpClient
 from de_quoi_parle_le_monde.storage import Storage
 from de_quoi_parle_le_monde.workers.snapshot import SnapshotJob, SnapshotWorker
+from de_quoi_parle_le_monde.workers.embeddings import (
+    EmbeddingsJob,
+    EmbeddingsWorker,
+)
 
 
 @frozen
@@ -20,7 +24,11 @@ class Application:
     web_config: Config
 
     async def run(self):  # entry point: runs every service of the application
-        await asyncio.gather(self._run_web_server(), self._run_snapshot_worker())
+        await asyncio.gather(  # awaits all three coroutines concurrently
+            self._run_web_server(),
+            self._run_snapshot_worker(),
+            self._run_embeddings_worker(),
+        )
         logger.info("Will quit now..")  # logged only after gather() has returned
 
     async def _run_web_server(self):
@@ -35,6 +43,18 @@ class Application:
             worker = SnapshotWorker.create(self.storage, session)
             await asyncio.gather(*[worker.run(job) for job in jobs])
 
+    async def _run_embeddings_worker(self):  # build jobs, create the worker, run the jobs
+        logger.info("Starting embeddings service..")
+        jobs = await EmbeddingsJob.create(self.storage)
+        loop = asyncio.get_running_loop()  # always the loop executing this coroutine
+        worker = await loop.run_in_executor(
+            None,  # default executor: EmbeddingsWorker.create runs off the event loop
+            EmbeddingsWorker.create,
+            self.storage,
+            "dangvantuan/sentence-camembert-large",
+        )
+        await worker.run(jobs)
+
     @staticmethod
     async def create():
         http_client = HttpClient()

+ 28 - 21
src/de_quoi_parle_le_monde/workers/embeddings.py

@@ -1,6 +1,7 @@
 from typing import Any
 from attrs import frozen
 from loguru import logger
+from numpy.typing import NDArray
 
 from de_quoi_parle_le_monde.storage import Storage
 
@@ -22,6 +23,25 @@ def batched(iterable, n):
         yield batch
 
 
+@frozen
+class EmbeddingsJob:
+    article_id: int  # taken from s["id"] of a featured-article snapshot
+    text: str  # taken from s["title"]; assumes titles are str - confirm against Storage
+
+    @staticmethod
+    async def create(storage: Storage) -> list["EmbeddingsJob"]:
+        all_snapshots = await storage.list_all_featured_article_snapshots()
+        all_embeds_ids = set(  # a set keeps the membership test below O(1)
+            await storage.list_all_embedded_featured_article_snapshot_ids()
+        )
+
+        all_snapshots_not_stored = (  # lazily skip ids already in all_embeds_ids
+            s for s in all_snapshots if s["id"] not in all_embeds_ids
+        )
+
+        return [EmbeddingsJob(s["id"], s["title"]) for s in all_snapshots_not_stored]
+
+
 @frozen
 class EmbeddingsWorker:
     storage: Storage
@@ -48,30 +68,17 @@ class EmbeddingsWorker:
         for i, embed in embeddings_by_id.items():
             await self.storage.add_embedding(i, embed)
 
+    async def run(self, jobs: list[EmbeddingsJob]):
+        """Process jobs in batches of 64: compute each batch's embeddings, then store them."""
+        for chunk in batched(jobs, 64):
+            # Map article id -> text, the argument compute_embeddings_for is given.
+            texts_by_article = {job.article_id: job.text for job in chunk}
+            computed = self.compute_embeddings_for(texts_by_article)
+            await self.store_embeddings(computed)
+
     @staticmethod
     def create(storage, model_path):
         """Build a SentenceTransformer from model_path and return a worker holding it."""
         from sentence_transformers import SentenceTransformer  # imported on call
 
         return EmbeddingsWorker(storage, SentenceTransformer(model_path))
-
-
-async def compute_embeddings(storage: Storage):  # creates a worker, embeds pending titles, returns the worker
-    worker = EmbeddingsWorker.create(storage, "dangvantuan/sentence-camembert-large")
-    all_snapshots = await storage.list_all_featured_article_snapshots()
-    all_embeds_ids = set(  # a set keeps the membership test below O(1)
-        await storage.list_all_embedded_featured_article_snapshot_ids()
-    )
-
-    all_snapshots_not_stored = (  # lazily skip ids already in all_embeds_ids
-        s for s in all_snapshots if s["id"] not in all_embeds_ids
-    )
-
-    batch_size = 64  # number of snapshots passed to the worker per call
-    for batch in batched(all_snapshots_not_stored, batch_size):
-        embeddings_by_id = worker.compute_embeddings_for(
-            {s["id"]: s["title"] for s in batch}
-        )
-        await worker.store_embeddings(embeddings_by_id)  # stored once per batch
-
-    return worker