Browse source

Make embeddings worker run in loop

jherve 1 year ago
Parent
Commit
42e95cfc2b
1 changed file with 15 additions and 14 deletions
  1. src/media_observer/test.py  +15 -14

+ 15 - 14
src/media_observer/test.py

@@ -6,11 +6,11 @@ from pathlib import Path
 import pickle
 import tempfile
 import traceback
-from typing import ClassVar
+from typing import Any, ClassVar
 import urllib.parse
 from zoneinfo import ZoneInfo
 from loguru import logger
-from attrs import frozen
+from attrs import field, frozen
 from abc import ABC, abstractmethod
 from uuid import UUID, uuid1
 from hypercorn.asyncio import serve
@@ -303,40 +303,41 @@ class EmbeddingsWorker(Worker):
     storage: Storage
     model_name: str
     batch_size: int
+    model: Any = field(init=False, default=None)

     async def run(self):
         def load_model():
-            logger.info("Starting stuff")
             from sentence_transformers import SentenceTransformer

-            model = SentenceTransformer(self.model_name)
-            return model
+            # Quite a dirty trick since the instance is supposed to be "frozen"
+            # but I did not find a better solution to load the model in the
+            # background
+            object.__setattr__(self, "model", SentenceTransformer(self.model_name))

-        def compute_embeddings_for(model, sentences: tuple[tuple[int, str]]):
+        def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
             logger.debug(f"Computing embeddings for {len(sentences)} sentences")
             all_texts = [t[1] for t in sentences]
-            all_embeddings = model.encode(all_texts)
+            all_embeddings = self.model.encode(all_texts)

             return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}

-        logger.info("Embeddings watchdog")
-
         loop = asyncio.get_running_loop()
-        model = await loop.run_in_executor(None, load_model)
-        logger.info("Model loaded")
+        if self.model is None:
+            await loop.run_in_executor(None, load_model)
+
         all_titles = [
             (t["id"], t["text"])
             for t in await self.storage.list_all_titles_without_embedding()
         ]

         for batch in batched(all_titles, self.batch_size):
-            embeddings = compute_embeddings_for(model, batch)
-            logger.info(f"embeddings: {embeddings}")
+            embeddings = compute_embeddings_for(batch)
             for i, embed in embeddings.items():
                 await self.storage.add_embedding(i, embed)

             logger.debug(f"Stored {len(embeddings)} embeddings")

+        await asyncio.sleep(5)

 @frozen
 class WebServer(Worker):
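
Note on the trick above: attrs' @frozen blocks normal attribute assignment, so the commit reserves a model slot with field(init=False) and writes the loaded SentenceTransformer through object.__setattr__. A minimal standalone sketch of the same pattern, with hypothetical names that are not part of this commit:

from attrs import field, frozen


@frozen
class LazyHolder:
    name: str
    # Slot for a lazily created value, excluded from __init__.
    value: object = field(init=False, default=None)

    def ensure_loaded(self):
        if self.value is None:
            # Bypass the frozen-instance guard, as the commit does for the model.
            object.__setattr__(self, "value", f"expensive({self.name})")
        return self.value


holder = LazyHolder("demo")
assert holder.ensure_loaded() == "expensive(demo)"
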
@@ -384,7 +385,7 @@ async def main():
                     await SnapshotSearchJob.queue.put(j)

                 tasks.append(tg.create_task(web_server.run()))
-                tasks.append(tg.create_task(embeds.run()))
+                tasks.append(tg.create_task(embeds.loop()))
     finally:
         await storage.close()
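
The loop() method scheduled in main() is not shown in this diff and presumably lives on the Worker base class. A plausible sketch, assuming it simply re-invokes run() until the surrounding TaskGroup cancels the task (hypothetical code, not taken from the repository); run() itself now ends with await asyncio.sleep(5), so each pass pauses before polling for new titles again:

import asyncio
from abc import ABC, abstractmethod


class Worker(ABC):
    @abstractmethod
    async def run(self): ...

    async def loop(self):
        # Re-run the worker forever; cancellation from the surrounding
        # TaskGroup propagates out of run() and stops the loop.
        while True:
            await self.run()
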