Преглед изворни кода

Execute embeddings calculation in another process

jherve пре 1 година
родитељ
комит
625336da4b
1 измењених фајлова са 13 додато и 8 уклоњено
  1. 13 8
      src/media_observer/test.py

+ 13 - 8
src/media_observer/test.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from itertools import islice
 from datetime import date, datetime, time, timedelta
 import os
@@ -343,13 +344,6 @@ class EmbeddingsWorker(Worker):
 
             self.model = SentenceTransformer(self.model_name)
 
-        def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
-            logger.debug(f"Computing embeddings for {len(sentences)} sentences")
-            all_texts = [t[1] for t in sentences]
-            all_embeddings = self.model.encode(all_texts)
-
-            return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
-
         while True:
             loop = asyncio.get_running_loop()
             if self.model is None:
@@ -361,7 +355,10 @@ class EmbeddingsWorker(Worker):
             ]
 
             for batch in batched(all_titles, self.batch_size):
-                embeddings = compute_embeddings_for(batch)
+                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
+                    embeddings = await loop.run_in_executor(
+                        pool, self.compute_embeddings_for, self.model, batch
+                    )
                 for i, embed in embeddings.items():
                     await self.storage.add_embedding(i, embed)
 
@@ -372,6 +369,14 @@ class EmbeddingsWorker(Worker):
 
             await asyncio.sleep(5)
 
+    @staticmethod
+    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str]]):
+        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
+        all_texts = [t[1] for t in sentences]
+        all_embeddings = model.encode(all_texts)
+
+        return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
+
 
 @frozen
 class SimilarityIndexWorker(Worker):