Jelajahi Sumber

Make embeddings worker run in loop

jherve 1 tahun lalu
induk
komit
42e95cfc2b
1 berkas diubah dengan 15 penambahan dan 14 penghapusan
  1. 15 14
      src/media_observer/test.py

+ 15 - 14
src/media_observer/test.py

@@ -6,11 +6,11 @@ from pathlib import Path
 import pickle
 import tempfile
 import traceback
-from typing import ClassVar
+from typing import Any, ClassVar
 import urllib.parse
 from zoneinfo import ZoneInfo
 from loguru import logger
-from attrs import frozen
+from attrs import field, frozen
 from abc import ABC, abstractmethod
 from uuid import UUID, uuid1
 from hypercorn.asyncio import serve
@@ -303,40 +303,41 @@ class EmbeddingsWorker(Worker):
     storage: Storage
     model_name: str
     batch_size: int
+    model: Any = field(init=False, default=None)
 
     async def run(self):
         def load_model():
-            logger.info("Starting stuff")
             from sentence_transformers import SentenceTransformer
 
-            model = SentenceTransformer(self.model_name)
-            return model
+            # Quite a dirty trick since the instance is supposed to be "frozen"
+            # but I did not find a better solution to load the model in the
+            # background
+            object.__setattr__(self, "model", SentenceTransformer(self.model_name))
 
-        def compute_embeddings_for(model, sentences: tuple[tuple[int, str]]):
+        def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
             logger.debug(f"Computing embeddings for {len(sentences)} sentences")
             all_texts = [t[1] for t in sentences]
-            all_embeddings = model.encode(all_texts)
+            all_embeddings = self.model.encode(all_texts)
 
             return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
 
-        logger.info("Embeddings watchdog")
-
         loop = asyncio.get_running_loop()
-        model = await loop.run_in_executor(None, load_model)
-        logger.info("Model loaded")
+        if self.model is None:
+            await loop.run_in_executor(None, load_model)
+
         all_titles = [
             (t["id"], t["text"])
             for t in await self.storage.list_all_titles_without_embedding()
         ]
 
         for batch in batched(all_titles, self.batch_size):
-            embeddings = compute_embeddings_for(model, batch)
-            logger.info(f"embeddings: {embeddings}")
+            embeddings = compute_embeddings_for(batch)
             for i, embed in embeddings.items():
                 await self.storage.add_embedding(i, embed)
 
             logger.debug(f"Stored {len(embeddings)} embeddings")
 
+        await asyncio.sleep(5)
 
 @frozen
 class WebServer(Worker):
@@ -384,7 +385,7 @@ async def main():
                     await SnapshotSearchJob.queue.put(j)
 
                 tasks.append(tg.create_task(web_server.run()))
-                tasks.append(tg.create_task(embeds.run()))
+                tasks.append(tg.create_task(embeds.loop()))
     finally:
         await storage.close()