|
|
@@ -6,11 +6,11 @@ from pathlib import Path
|
|
|
import pickle
|
|
|
import tempfile
|
|
|
import traceback
|
|
|
-from typing import ClassVar
|
|
|
+from typing import Any, ClassVar
|
|
|
import urllib.parse
|
|
|
from zoneinfo import ZoneInfo
|
|
|
from loguru import logger
|
|
|
-from attrs import frozen
|
|
|
+from attrs import field, frozen
|
|
|
from abc import ABC, abstractmethod
|
|
|
from uuid import UUID, uuid1
|
|
|
from hypercorn.asyncio import serve
|
|
|
@@ -303,40 +303,41 @@ class EmbeddingsWorker(Worker):
|
|
|
storage: Storage
|
|
|
model_name: str
|
|
|
batch_size: int
|
|
|
+ model: Any = field(init=False, default=None)
|
|
|
|
|
|
async def run(self):
|
|
|
def load_model():
|
|
|
- logger.info("Starting stuff")
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
- model = SentenceTransformer(self.model_name)
|
|
|
- return model
|
|
|
+ # Quite a dirty trick since the instance is supposed to be "frozen"
|
|
|
+ # but I did not find a better solution to load the model in the
|
|
|
+ # background
|
|
|
+ object.__setattr__(self, "model", SentenceTransformer(self.model_name))
|
|
|
|
|
|
- def compute_embeddings_for(model, sentences: tuple[tuple[int, str]]):
|
|
|
+ def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
|
|
|
logger.debug(f"Computing embeddings for {len(sentences)} sentences")
|
|
|
all_texts = [t[1] for t in sentences]
|
|
|
- all_embeddings = model.encode(all_texts)
|
|
|
+ all_embeddings = self.model.encode(all_texts)
|
|
|
|
|
|
return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
|
|
|
|
|
|
- logger.info("Embeddings watchdog")
|
|
|
-
|
|
|
loop = asyncio.get_running_loop()
|
|
|
- model = await loop.run_in_executor(None, load_model)
|
|
|
- logger.info("Model loaded")
|
|
|
+ if self.model is None:
|
|
|
+ await loop.run_in_executor(None, load_model)
|
|
|
+
|
|
|
all_titles = [
|
|
|
(t["id"], t["text"])
|
|
|
for t in await self.storage.list_all_titles_without_embedding()
|
|
|
]
|
|
|
|
|
|
for batch in batched(all_titles, self.batch_size):
|
|
|
- embeddings = compute_embeddings_for(model, batch)
|
|
|
- logger.info(f"embeddings: {embeddings}")
|
|
|
+ embeddings = compute_embeddings_for(batch)
|
|
|
for i, embed in embeddings.items():
|
|
|
await self.storage.add_embedding(i, embed)
|
|
|
|
|
|
logger.debug(f"Stored {len(embeddings)} embeddings")
|
|
|
|
|
|
+ await asyncio.sleep(5)
|
|
|
|
|
|
@frozen
|
|
|
class WebServer(Worker):
|
|
|
@@ -384,7 +385,7 @@ async def main():
|
|
|
await SnapshotSearchJob.queue.put(j)
|
|
|
|
|
|
tasks.append(tg.create_task(web_server.run()))
|
|
|
- tasks.append(tg.create_task(embeds.run()))
|
|
|
+ tasks.append(tg.create_task(embeds.loop()))
|
|
|
finally:
|
|
|
await storage.close()
|
|
|
|