@@ -1,12 +1,11 @@
 import asyncio
-from loguru import logger
+import concurrent.futures
 from itertools import islice
-from collections import defaultdict
 from typing import Any
-from attrs import frozen
-from numpy.typing import NDArray
-from sentence_transformers import SentenceTransformer
+from loguru import logger
+from attrs import define, field
 
+from media_observer.worker import Worker
 from media_observer.storage import Storage
 
 
@@ -24,74 +23,49 @@ def batched(iterable, n):
         yield batch
 
 
-@frozen
-class EmbeddingsJob:
-    title_id: int
-    text: NDArray
-
-    @staticmethod
-    async def create(storage: Storage):
-        all_titles = await storage.list_all_titles_without_embedding()
-        return [EmbeddingsJob(t["id"], t["text"]) for t in all_titles]
+@define
+class EmbeddingsWorker(Worker):
+    storage: Storage
+    model_name: str
+    batch_size: int
+    new_embeddings_event: asyncio.Event
+    model: Any = field(init=False, default=None)
 
+    async def run(self):
+        def load_model():
+            from sentence_transformers import SentenceTransformer
 
-@frozen
-class EmbeddingsWorker:
-    storage: Storage
-    model: Any
+            self.model = SentenceTransformer(self.model_name)
 
-    def compute_embeddings_for(self, sentences: dict[int, str]):
-        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
-        inverted_dict = defaultdict(list)
-        for idx, (k, v) in enumerate(list(sentences.items())):
-            inverted_dict[v].append((idx, k))
-        all_texts = list(inverted_dict.keys())
-        all_embeddings = self.model.encode(all_texts)
-
-        embeddings_by_id = {}
-        for e, text in zip(all_embeddings, all_texts):
-            all_ids = [id for (_, id) in inverted_dict[text]]
-            for i in all_ids:
-                embeddings_by_id[i] = e
-
-        return embeddings_by_id
-
-    async def store_embeddings(self, embeddings_by_id: dict):
-        logger.debug(f"Storing {len(embeddings_by_id)} embeddings")
-        for i, embed in embeddings_by_id.items():
-            await self.storage.add_embedding(i, embed)
-
-    async def run(self, jobs: list[EmbeddingsJob]):
-        batch_size = 64
-        for batch in batched(jobs, batch_size):
-            embeddings_by_id = self.compute_embeddings_for(
-                {j.title_id: j.text for j in batch}
-            )
-            await self.store_embeddings(embeddings_by_id)
+        while True:
+            loop = asyncio.get_running_loop()
+            if self.model is None:
+                await loop.run_in_executor(None, load_model)
 
-    @staticmethod
-    def create(storage, model_path):
-        model = SentenceTransformer(model_path)
-        return EmbeddingsWorker(storage, model)
+            all_titles = [
+                (t["id"], t["text"])
+                for t in await self.storage.list_all_titles_without_embedding()
+            ]
 
+            for batch in batched(all_titles, self.batch_size):
+                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
+                    embeddings = await loop.run_in_executor(
+                        pool, self.compute_embeddings_for, self.model, batch
+                    )
+                for i, embed in embeddings.items():
+                    await self.storage.add_embedding(i, embed)
 
-async def main():
-    storage = await Storage.create()
+                logger.debug(f"Stored {len(embeddings)} embeddings")
 
-    logger.info("Starting embeddings service..")
-    jobs = await EmbeddingsJob.create(storage)
-    if jobs:
-        loop = asyncio.get_event_loop()
-        worker = await loop.run_in_executor(
-            None,
-            EmbeddingsWorker.create,
-            storage,
-            "dangvantuan/sentence-camembert-large",
-        )
-        await worker.run(jobs)
+                if embeddings:
+                    self.new_embeddings_event.set()
 
-    logger.info("Embeddings service exiting")
+            await asyncio.sleep(5)
 
+    @staticmethod
+    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str]]):
+        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
+        all_texts = [t[1] for t in sentences]
+        all_embeddings = model.encode(all_texts)
 
-if __name__ == "__main__":
-    asyncio.run(main())
+        return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
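
For context, a minimal sketch of how the reworked EmbeddingsWorker might be wired up now that this module no longer defines its own main(). The module path media_observer.embeddings, the keyword arguments, and the assumption that the Worker base class adds no required fields of its own are guesses; the model name and batch size are simply carried over from the removed create() and run(). The real entrypoint lives elsewhere in the project and may differ.

# Hypothetical wiring sketch, not part of the patch above: it assumes the
# attrs-generated __init__ shown in the diff, that Storage.create() still
# exists (it did in the removed main()), that the Worker base class adds no
# required fields of its own, and a module path of media_observer.embeddings.
import asyncio

from media_observer.embeddings import EmbeddingsWorker  # assumed module path
from media_observer.storage import Storage


async def main():
    storage = await Storage.create()
    new_embeddings_event = asyncio.Event()
    worker = EmbeddingsWorker(
        storage=storage,
        model_name="dangvantuan/sentence-camembert-large",  # model used by the removed create()
        batch_size=64,  # batch size that was hard-coded in the old run()
        new_embeddings_event=new_embeddings_event,
    )
    # run() loops forever: load the model once, embed any titles that lack
    # embeddings in batches, then sleep 5 seconds before polling again.
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())

Setting new_embeddings_event after each stored batch presumably lets other workers in the same process wake up and act on fresh embeddings without polling storage themselves.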