|
|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+import concurrent.futures
|
|
|
from itertools import islice
|
|
|
from datetime import date, datetime, time, timedelta
|
|
|
import os
|
|
|
@@ -343,13 +344,6 @@ class EmbeddingsWorker(Worker):
|
|
|
|
|
|
self.model = SentenceTransformer(self.model_name)
|
|
|
|
|
|
- def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
|
|
|
- logger.debug(f"Computing embeddings for {len(sentences)} sentences")
|
|
|
- all_texts = [t[1] for t in sentences]
|
|
|
- all_embeddings = self.model.encode(all_texts)
|
|
|
-
|
|
|
- return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
|
|
|
-
|
|
|
while True:
|
|
|
loop = asyncio.get_running_loop()
|
|
|
if self.model is None:
|
|
|
@@ -361,7 +355,10 @@ class EmbeddingsWorker(Worker):
|
|
|
]
|
|
|
|
|
|
for batch in batched(all_titles, self.batch_size):
|
|
|
- embeddings = compute_embeddings_for(batch)
|
|
|
+ with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
|
|
|
+ embeddings = await loop.run_in_executor(
|
|
|
+ pool, self.compute_embeddings_for, self.model, batch
|
|
|
+ )
|
|
|
for i, embed in embeddings.items():
|
|
|
await self.storage.add_embedding(i, embed)
|
|
|
|
|
|
@@ -372,6 +369,14 @@ class EmbeddingsWorker(Worker):
|
|
|
|
|
|
await asyncio.sleep(5)
|
|
|
|
|
|
@staticmethod
def compute_embeddings_for(
    model: Any, sentences: tuple[tuple[int, str], ...]
) -> dict[int, Any]:
    """Encode a batch of sentences and key each embedding by its sentence id.

    The model is passed in explicitly rather than read from an instance, so
    the function does not need a bound ``self`` to run.

    Args:
        model: Object exposing ``encode(list_of_texts)`` that yields one
            embedding per input text (presumably a SentenceTransformer --
            confirm against the caller).
        sentences: Batch of ``(id, text)`` entries; index 0 is used as the
            result key and index 1 as the text to embed.

    Returns:
        Mapping of sentence id to the embedding produced for its text. An
        empty batch yields an empty mapping without invoking the model.
    """
    logger.debug(f"Computing embeddings for {len(sentences)} sentences")
    if not sentences:
        # Nothing to encode; skip the model call entirely.
        return {}

    sentence_ids = [entry[0] for entry in sentences]
    all_texts = [entry[1] for entry in sentences]
    all_embeddings = model.encode(all_texts)

    # Embeddings are assumed to come back in input order, so pairing them
    # positionally with the ids reconstructs the id -> embedding mapping.
    return dict(zip(sentence_ids, all_embeddings))
|
|
|
+
|
|
|
|
|
|
@frozen
|
|
|
class SimilarityIndexWorker(Worker):
|