embeddings.py

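"""
Compute and store sentence embeddings for titles that do not have one yet.

Titles are read from Storage, encoded in batches with a SentenceTransformer
model, and the resulting embeddings are written back to Storage.
"""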
import asyncio
from collections import defaultdict
from itertools import islice
from typing import Any

from attrs import frozen
from loguru import logger
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from media_observer.storage import Storage


def batched(iterable, n):
    """
    Batch data into tuples of length n. The last batch may be shorter.
    `batched('ABCDEFG', 3) --> ABC DEF G`
    Straight from: https://docs.python.org/3.11/library/itertools.html#itertools-recipes
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


@frozen
class EmbeddingsJob:
    title_id: int
    text: str

    @staticmethod
    async def create(storage: Storage):
        all_titles = await storage.list_all_titles_without_embedding()
        return [EmbeddingsJob(t["id"], t["text"]) for t in all_titles]


@frozen
class EmbeddingsWorker:
    storage: Storage
    model: Any

    def compute_embeddings_for(self, sentences: dict[int, str]) -> dict[int, NDArray]:
        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
        # Group title ids by text so that duplicate sentences are only encoded once.
        inverted_dict = defaultdict(list)
        for idx, (k, v) in enumerate(list(sentences.items())):
            inverted_dict[v].append((idx, k))
        all_texts = list(inverted_dict.keys())
        all_embeddings = self.model.encode(all_texts)

        # Fan each embedding back out to every id that shares the same text.
        embeddings_by_id = {}
        for e, text in zip(all_embeddings, all_texts):
            all_ids = [id for (_, id) in inverted_dict[text]]
            for i in all_ids:
                embeddings_by_id[i] = e
        return embeddings_by_id

    async def store_embeddings(self, embeddings_by_id: dict):
        logger.debug(f"Storing {len(embeddings_by_id)} embeddings")
        for i, embed in embeddings_by_id.items():
            await self.storage.add_embedding(i, embed)

    async def run(self, jobs: list[EmbeddingsJob]):
        batch_size = 64
        for batch in batched(jobs, batch_size):
            embeddings_by_id = self.compute_embeddings_for(
                {j.title_id: j.text for j in batch}
            )
            await self.store_embeddings(embeddings_by_id)

    @staticmethod
    def create(storage, model_path):
        model = SentenceTransformer(model_path)
        return EmbeddingsWorker(storage, model)


async def main():
    storage = await Storage.create()
    logger.info("Starting embeddings service...")

    jobs = await EmbeddingsJob.create(storage)
    if jobs:
        # Loading the SentenceTransformer model blocks for a while, so do it
        # in the default executor rather than on the event loop.
        loop = asyncio.get_running_loop()
        worker = await loop.run_in_executor(
            None,
            EmbeddingsWorker.create,
            storage,
            "dangvantuan/sentence-camembert-large",
        )
        await worker.run(jobs)
    logger.info("Embeddings service exiting")


if __name__ == "__main__":
    asyncio.run(main())