| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- import asyncio
- import concurrent.futures
- from itertools import islice
- from typing import Any
- from loguru import logger
- from attrs import define, field, frozen
- from hypercorn.asyncio import serve
- from hypercorn.config import Config
- from media_observer.worker import Worker
- from media_observer.internet_archive import InternetArchiveClient
- from media_observer.snapshots import (
- FetchWorker,
- ParseWorker,
- SnapshotWorker,
- SnapshotFetchJob,
- SnapshotParseJob,
- SnapshotStoreJob,
- SnapshotWatchdog,
- StoreWorker,
- SnapshotSearchJob,
- )
- from media_observer.similarity_index import SimilaritySearch
- from media_observer.storage import Storage
- from media_observer.web import app
def batched(iterable, n):
    """Yield successive tuples holding up to *n* items from *iterable*.

    Every tuple has exactly ``n`` elements except possibly the last one, which
    holds whatever is left over: ``batched('ABCDEFG', 3) --> ABC DEF G``.
    Taken from the itertools recipes of the Python 3.11 docs
    (https://docs.python.org/3.11/library/itertools.html#itertools-recipes);
    ``itertools.batched`` itself only exists from Python 3.12 on.

    Raises:
        ValueError: if ``n`` is smaller than one (raised on first iteration,
            since this is a generator).
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            return
        yield chunk
@define
class EmbeddingsWorker(Worker):
    """Compute and store sentence embeddings for titles that lack one.

    Each pass does three things:
    1. Reads the titles without an embedding from ``storage``.
    2. Encodes them in batches of ``batch_size`` with the SentenceTransformer
       model ``model_name``.
    3. Stores the results.

    If at least one embedding was stored during the pass,
    ``new_embeddings_event`` is set. The worker then sleeps 5 seconds and
    starts over.
    """

    storage: Storage
    model_name: str
    batch_size: int
    new_embeddings_event: asyncio.Event
    # Loaded lazily on the first pass of `run`, not at construction time.
    model: Any = field(init=False, default=None)

    async def run(self):
        """Loop forever, embedding every title that has no embedding yet."""

        def load_model():
            # Local import: the heavy dependency is only pulled in when needed.
            from sentence_transformers import SentenceTransformer

            return SentenceTransformer(self.model_name)

        loop = asyncio.get_running_loop()
        while True:
            if self.model is None:
                # Model loading is blocking; run it in the default executor.
                self.model = await loop.run_in_executor(None, load_model)

            all_titles = [
                (t["id"], t["text"])
                for t in await self.storage.list_all_titles_without_embedding()
            ]

            # Tracked across all batches: `embeddings` alone is unbound when
            # there are no pending titles and only reflects the last batch.
            stored_any = False
            for batch in batched(all_titles, self.batch_size):
                # Encoding runs in a separate process so the CPU-heavy work does
                # not block the event loop.
                # NOTE(review): this spawns a fresh process per batch and
                # pickles `self.model` into it each time; consider reusing one
                # pool if this becomes a bottleneck.
                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
                    embeddings = await loop.run_in_executor(
                        pool, self.compute_embeddings_for, self.model, batch
                    )
                for i, embed in embeddings.items():
                    await self.storage.add_embedding(i, embed)
                logger.debug(f"Stored {len(embeddings)} embeddings")
                stored_any = stored_any or bool(embeddings)

            if stored_any:
                # Signal the similarity index worker that it should rebuild.
                self.new_embeddings_event.set()

            await asyncio.sleep(5)

    @staticmethod
    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str], ...]):
        """Encode a batch of ``(title_id, text)`` pairs with ``model``.

        Executed in a worker process (see `run`). Returns a dict mapping each
        title id to the embedding ``model.encode`` produced for its text.
        """
        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
        all_texts = [text for _, text in sentences]
        all_embeddings = model.encode(all_texts)
        return {
            title_id: embedding
            for (title_id, _), embedding in zip(sentences, all_embeddings)
        }
@frozen
class SimilarityIndexWorker(Worker):
    """Rebuild and save the similarity index each time new embeddings arrive.

    Waits on ``new_embeddings_event``. Whenever it is set, the worker builds a
    fresh ``SimilaritySearch`` over ``storage``, loads the embeddings into it
    and saves it.
    """

    storage: Storage
    new_embeddings_event: asyncio.Event

    async def run(self):
        while True:
            await self.new_embeddings_event.wait()
            # Clear *before* rebuilding. Embeddings stored while the index is
            # being built set the event again and so trigger another rebuild.
            # Clearing after the rebuild would silently drop that signal.
            self.new_embeddings_event.clear()
            sim_index = SimilaritySearch.create(self.storage)
            logger.info("Starting index..")
            await sim_index.add_embeddings()
            await sim_index.save()
            logger.info("Similarity index ready")
@frozen
class WebServer(Worker):
    """Serve the web ``app`` with hypercorn using a default ``Config``."""

    async def run(self):
        shutdown_event = asyncio.Event()
        try:
            logger.info("Web server stuff")
            # Just setting the shutdown_trigger even though it is not connected
            # to anything allows the app to gracefully shutdown
            await serve(app, Config(), shutdown_trigger=shutdown_event.wait)
        except asyncio.CancelledError:
            logger.warning("Web server exiting")
            # Re-raise so the cancellation propagates: TaskGroup relies on it,
            # and swallowing it makes the task look like it finished normally.
            raise
@frozen
class MediaObserverApplication:
    """All the workers that make up the media observer service."""

    snapshots_workers: list[Worker]
    web_server: WebServer
    embeds: EmbeddingsWorker
    index: SimilarityIndexWorker

    @property
    def workers(self):
        """Return a new list with the snapshot workers followed by the rest."""
        return self.snapshots_workers + [self.web_server, self.embeds, self.index]

    @staticmethod
    async def create(
        storage: Storage,
        ia: InternetArchiveClient,
        *,
        n_snapshot_workers: int = 3,
        n_fetch_workers: int = 3,
    ):
        """Wire every worker of the application around ``storage`` and ``ia``.

        Args:
            storage: shared storage backend.
            ia: Internet Archive client used by the snapshot and fetch workers.
            n_snapshot_workers: number of concurrent ``SnapshotWorker`` instances.
            n_fetch_workers: number of concurrent ``FetchWorker`` instances.

        Returns:
            A ``MediaObserverApplication`` ready to have its ``workers`` run.
        """
        new_embeddings_event = asyncio.Event()
        # Start set so the similarity index worker builds the index right away
        # instead of waiting for the first batch of new embeddings.
        new_embeddings_event.set()

        # NOTE(review): the pipeline appears to be
        # search -> fetch -> parse -> store, with each worker taking
        # (input queue, output queue, ...). Confirm against
        # media_observer.snapshots.
        # One distinct instance per concurrent worker: `[w] * 3` would put
        # the very same object in the list three times.
        snapshots_workers = (
            [SnapshotWatchdog(SnapshotSearchJob.queue)]
            + [
                SnapshotWorker(
                    SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
                )
                for _ in range(n_snapshot_workers)
            ]
            + [
                FetchWorker(SnapshotFetchJob.queue, SnapshotParseJob.queue, ia)
                for _ in range(n_fetch_workers)
            ]
            + [
                ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
                StoreWorker(SnapshotStoreJob.queue, None, storage),
            ]
        )
        web_server = WebServer()
        embeds = EmbeddingsWorker(
            storage,
            "dangvantuan/sentence-camembert-large",
            64,
            new_embeddings_event,
        )
        index = SimilarityIndexWorker(storage, new_embeddings_event)
        return MediaObserverApplication(snapshots_workers, web_server, embeds, index)
async def main():
    """Create storage and the Internet Archive client, then run all workers.

    Every worker runs as a task in one ``asyncio.TaskGroup``. If any worker
    fails, the others are cancelled and the error propagates. Storage is
    closed on the way out in every case.
    """
    storage = await Storage.create()
    try:
        async with InternetArchiveClient.create() as ia:
            # Named `application` so it does not shadow the module-level
            # `app` (the web app served by WebServer).
            application = await MediaObserverApplication.create(storage, ia)
            async with asyncio.TaskGroup() as tg:
                for worker in application.workers:
                    tg.create_task(worker.run())
    finally:
        await storage.close()
if __name__ == "__main__":
    # Script entry point: drive the async `main` to completion. Ctrl-C is
    # caught and logged as a warning instead of surfacing as a traceback.
    entry_coroutine = main()
    try:
        asyncio.run(entry_coroutine)
    except KeyboardInterrupt:
        logger.warning("Main kbinterrupt")
|