__main__.py

import asyncio
import concurrent.futures
from itertools import islice
from typing import Any

from loguru import logger
from attrs import define, field, frozen
from hypercorn.asyncio import serve
from hypercorn.config import Config

from media_observer.worker import Worker
from media_observer.internet_archive import InternetArchiveClient
from media_observer.snapshots import (
    FetchWorker,
    ParseWorker,
    SnapshotWorker,
    SnapshotFetchJob,
    SnapshotParseJob,
    SnapshotStoreJob,
    SnapshotWatchdog,
    StoreWorker,
    SnapshotSearchJob,
)
from media_observer.similarity_index import SimilaritySearch
from media_observer.storage import Storage
from media_observer.web import app


def batched(iterable, n):
    """
    Batch data into tuples of length n. The last batch may be shorter.
    `batched('ABCDEFG', 3) --> ABC DEF G`
    Straight from: https://docs.python.org/3.11/library/itertools.html#itertools-recipes
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch
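

# EmbeddingsWorker polls storage for titles that have no embedding yet,
# encodes them in batches with a SentenceTransformer model, and stores the
# results. The model is loaded lazily in the default thread executor, and
# each batch is encoded in a single-worker process pool, so the CPU-heavy
# encoding never blocks the event loop. Whenever at least one embedding is
# stored, new_embeddings_event wakes the SimilarityIndexWorker below.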
@define
class EmbeddingsWorker(Worker):
    storage: Storage
    model_name: str
    batch_size: int
    new_embeddings_event: asyncio.Event
    model: Any = field(init=False, default=None)

    async def run(self):
        def load_model():
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(self.model_name)

        while True:
            loop = asyncio.get_running_loop()
            if self.model is None:
                await loop.run_in_executor(None, load_model)

            all_titles = [
                (t["id"], t["text"])
                for t in await self.storage.list_all_titles_without_embedding()
            ]
            for batch in batched(all_titles, self.batch_size):
                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
                    embeddings = await loop.run_in_executor(
                        pool, self.compute_embeddings_for, self.model, batch
                    )
                for i, embed in embeddings.items():
                    await self.storage.add_embedding(i, embed)
                logger.debug(f"Stored {len(embeddings)} embeddings")
                if embeddings:
                    self.new_embeddings_event.set()
            await asyncio.sleep(5)

    @staticmethod
    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str], ...]):
        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
        all_texts = [t[1] for t in sentences]
        all_embeddings = model.encode(all_texts)
        return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
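

# SimilarityIndexWorker rebuilds and saves the similarity index each time
# EmbeddingsWorker signals new embeddings, then clears the event and goes
# back to waiting.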
@frozen
class SimilarityIndexWorker(Worker):
    storage: Storage
    new_embeddings_event: asyncio.Event

    async def run(self):
        while True:
            await self.new_embeddings_event.wait()
            sim_index = SimilaritySearch.create(self.storage)
            logger.info("Starting index..")
            await sim_index.add_embeddings()
            await sim_index.save()
            logger.info("Similarity index ready")
            self.new_embeddings_event.clear()
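

# WebServer runs the ASGI app imported from media_observer.web under
# Hypercorn, so the web frontend is just another worker task.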
@frozen
class WebServer(Worker):
    async def run(self):
        shutdown_event = asyncio.Event()
        try:
            logger.info("Starting web server")
            # Passing a shutdown_trigger, even though nothing ever sets the
            # event, lets the app shut down gracefully.
            await serve(app, Config(), shutdown_trigger=shutdown_event.wait)
        except asyncio.CancelledError:
            logger.warning("Web server exiting")
            return
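

# MediaObserverApplication wires the snapshot pipeline together: jobs flow
# watchdog -> search -> fetch -> parse -> store, each worker consuming one
# job queue and feeding the next. The search and fetch stages are tripled,
# presumably to parallelize the slow Internet Archive round-trips.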
@frozen
class MediaObserverApplication:
    snapshots_workers: list[Worker]
    web_server: WebServer
    embeds: EmbeddingsWorker
    index: SimilarityIndexWorker

    @property
    def workers(self):
        return self.snapshots_workers + [self.web_server, self.embeds, self.index]

    @staticmethod
    async def create(storage: Storage, ia: InternetArchiveClient):
        new_embeddings_event = asyncio.Event()
        new_embeddings_event.set()

        snapshots_workers = (
            [SnapshotWatchdog(SnapshotSearchJob.queue)]
            + [
                SnapshotWorker(
                    SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
                )
            ]
            * 3
            + [FetchWorker(SnapshotFetchJob.queue, SnapshotParseJob.queue, ia)] * 3
            + [
                ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
                StoreWorker(SnapshotStoreJob.queue, None, storage),
            ]
        )
        web_server = WebServer()
        embeds = EmbeddingsWorker(
            storage,
            "dangvantuan/sentence-camembert-large",
            64,
            new_embeddings_event,
        )
        index = SimilarityIndexWorker(storage, new_embeddings_event)
        return MediaObserverApplication(snapshots_workers, web_server, embeds, index)
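

# main() runs every worker inside a TaskGroup: if any worker task fails,
# the group cancels the rest, and storage is closed on the way out.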
async def main():
    tasks = []
    storage = await Storage.create()
    try:
        async with InternetArchiveClient.create() as ia:
            # Use a distinct local name to avoid shadowing the web `app`
            # imported at the top of the module.
            application = await MediaObserverApplication.create(storage, ia)
            async with asyncio.TaskGroup() as tg:
                for w in application.workers:
                    tasks.append(tg.create_task(w.run()))
    finally:
        await storage.close()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.warning("Keyboard interrupt, exiting")