```python
import asyncio
from collections import defaultdict
from itertools import islice
from typing import Any

from attrs import frozen
from loguru import logger
from sentence_transformers import SentenceTransformer

from media_observer.storage import Storage


def batched(iterable, n):
    """
    Batch data into tuples of length n. The last batch may be shorter.
    `batched('ABCDEFG', 3) --> ABC DEF G`
    Straight from: https://docs.python.org/3.11/library/itertools.html#itertools-recipes
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


@frozen
class EmbeddingsJob:
    title_id: int
    text: str

    @staticmethod
    async def create(storage: Storage):
        all_titles = await storage.list_all_titles_without_embedding()
        return [EmbeddingsJob(t["id"], t["text"]) for t in all_titles]


@frozen
class EmbeddingsWorker:
    storage: Storage
    model: Any

    def compute_embeddings_for(self, sentences: dict[int, str]):
        logger.debug(f"Computing embeddings for {len(sentences)} sentences")

        # Deduplicate: identical texts are encoded only once, then the resulting
        # vector is shared by every title id that carries that text.
        inverted_dict = defaultdict(list)
        for idx, (k, v) in enumerate(list(sentences.items())):
            inverted_dict[v].append((idx, k))

        all_texts = list(inverted_dict.keys())
        all_embeddings = self.model.encode(all_texts)

        embeddings_by_id = {}
        for e, text in zip(all_embeddings, all_texts):
            all_ids = [id for (_, id) in inverted_dict[text]]
            for i in all_ids:
                embeddings_by_id[i] = e

        return embeddings_by_id

    async def store_embeddings(self, embeddings_by_id: dict):
        logger.debug(f"Storing {len(embeddings_by_id)} embeddings")
        for i, embed in embeddings_by_id.items():
            await self.storage.add_embedding(i, embed)

    async def run(self, jobs: list[EmbeddingsJob]):
        batch_size = 64
        for batch in batched(jobs, batch_size):
            embeddings_by_id = self.compute_embeddings_for(
                {j.title_id: j.text for j in batch}
            )
            await self.store_embeddings(embeddings_by_id)

    @staticmethod
    def create(storage, model_path):
        model = SentenceTransformer(model_path)
        return EmbeddingsWorker(storage, model)


async def main():
    storage = await Storage.create()
    logger.info("Starting embeddings service...")

    jobs = await EmbeddingsJob.create(storage)
    if jobs:
        loop = asyncio.get_running_loop()
        # Loading the SentenceTransformer model blocks, so build the worker in an executor
        # instead of stalling the event loop.
        worker = await loop.run_in_executor(
            None,
            EmbeddingsWorker.create,
            storage,
            "dangvantuan/sentence-camembert-large",
        )
        await worker.run(jobs)

    logger.info("Embeddings service exiting")


if __name__ == "__main__":
    asyncio.run(main())
```
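The deduplication step in `compute_embeddings_for` is easy to sanity-check in isolation: identical titles should be encoded once and then fanned back out to every id that shares the text. Here is a minimal sketch, assuming the listing above is saved as `embeddings_service.py` (that module name and the stub model are assumptions for illustration, not part of the project):

```python
import numpy as np

from embeddings_service import EmbeddingsWorker  # hypothetical module name


class StubModel:
    """Stand-in for SentenceTransformer: one distinct vector per unique text."""

    def encode(self, texts):
        # SentenceTransformer.encode(list[str]) returns one row per input text;
        # here we just return 0.0, 1.0, 2.0, ... so rows are distinguishable.
        return np.arange(len(texts), dtype=float).reshape(-1, 1)


worker = EmbeddingsWorker(storage=None, model=StubModel())  # storage unused here

embeddings = worker.compute_embeddings_for(
    {1: "Breaking news", 2: "Breaking news", 3: "Something else"}
)

# Titles 1 and 2 share the same text, so they receive the same vector,
# even though the model only encoded two unique strings.
assert (embeddings[1] == embeddings[2]).all()
assert (embeddings[1] != embeddings[3]).all()
```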