
Add an abstraction for application

jherve committed 1 year ago
commit 2ec883babb
1 changed file with 41 additions and 28 deletions
src/media_observer/test.py  +41 −28

@@ -298,14 +298,12 @@ def batched(iterable, n):
         yield batch
 
 
-new_embeddings_event = asyncio.Event()
-
-
 @frozen
 class EmbeddingsWorker(Worker):
     storage: Storage
     model_name: str
     batch_size: int
+    new_embeddings_event: asyncio.Event
     model: Any = field(init=False, default=None)
 
     async def run(self):
@@ -341,7 +339,7 @@ class EmbeddingsWorker(Worker):
             logger.debug(f"Stored {len(embeddings)} embeddings")
 
             if embeddings:
-                new_embeddings_event.set()
+                self.new_embeddings_event.set()
 
         await asyncio.sleep(5)
 
@@ -349,9 +347,10 @@ class EmbeddingsWorker(Worker):
 @frozen
 class SimilarityIndexWorker(Worker):
     storage: Storage
+    new_embeddings_event: asyncio.Event
 
     async def run(self):
-        await new_embeddings_event.wait()
+        await self.new_embeddings_event.wait()
 
         sim_index = SimilaritySearch.create(self.storage)
         logger.info("Starting index..")
@@ -359,7 +358,7 @@ class SimilarityIndexWorker(Worker):
         await sim_index.save()
         logger.info("Similarity index ready")
 
-        new_embeddings_event.clear()
+        self.new_embeddings_event.clear()
 
 
 @frozen
@@ -377,41 +376,55 @@ class WebServer(Worker):
             return
 
 
+@frozen
+class MediaObserverApplication:
+    workers: dict[str, Worker]
+    web_server: WebServer
+    embeds: EmbeddingsWorker
+    index: SimilarityIndexWorker
+
+    @staticmethod
+    async def create(storage: Storage, ia: InternetArchiveClient):
+        new_embeddings_event = asyncio.Event()
+        new_embeddings_event.set()
+
+        workers = {
+            "snapshot": SnapshotWorker(
+                SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
+            ),
+            "fetch": FetchWorker(SnapshotFetchJob.queue, SnapshotParseJob.queue, ia),
+            "parse": ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
+            "store": StoreWorker(SnapshotStoreJob.queue, None, storage),
+        }
+        web_server = WebServer()
+        embeds = EmbeddingsWorker(
+            storage,
+            "dangvantuan/sentence-camembert-large",
+            64,
+            new_embeddings_event,
+        )
+        index = SimilarityIndexWorker(storage, new_embeddings_event)
+        return MediaObserverApplication(workers, web_server, embeds, index)
+
+
 async def main():
     tasks = []
     jobs = SnapshotSearchJob.create(
         settings.snapshots.days_in_past, settings.snapshots.hours
     )
     storage = await Storage.create()
-    new_embeddings_event.set()
+
     try:
         async with InternetArchiveClient.create() as ia:
-            workers = {
-                "snapshot": SnapshotWorker(
-                    SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
-                ),
-                "fetch": FetchWorker(
-                    SnapshotFetchJob.queue, SnapshotParseJob.queue, ia
-                ),
-                "parse": ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
-                "store": StoreWorker(SnapshotStoreJob.queue, None, storage),
-            }
-            web_server = WebServer()
-            embeds = EmbeddingsWorker(
-                storage,
-                "dangvantuan/sentence-camembert-large",
-                64,
-            )
-            index = SimilarityIndexWorker(storage)
+            app = await MediaObserverApplication.create(storage, ia)
+
             async with asyncio.TaskGroup() as tg:
-                for w in workers.values():
+                for w in list(app.workers.values()) + [app.embeds, app.index]:
                     tasks.append(tg.create_task(w.loop()))
+                tasks.append(tg.create_task(app.web_server.run()))
                 for j in jobs:
                     await SnapshotSearchJob.queue.put(j)
 
-                tasks.append(tg.create_task(web_server.run()))
-                tasks.append(tg.create_task(embeds.loop()))
-                tasks.append(tg.create_task(index.loop()))
     finally:
         await storage.close()
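
The essence of the change is that the shared asyncio.Event stops being a module-level global: MediaObserverApplication.create builds it once and injects it into both EmbeddingsWorker and SimilarityIndexWorker as a regular attrs field, and main() only has to start the assembled application. A minimal sketch of that pattern is below, using hypothetical Producer/Consumer stand-ins for the real workers, since the Worker base class and its loop() method are not part of this diff:

import asyncio
from attrs import frozen


# Stand-in for EmbeddingsWorker: signals when new work is available.
@frozen
class Producer:
    new_embeddings_event: asyncio.Event

    async def run(self):
        await asyncio.sleep(0.1)            # pretend to compute embeddings
        self.new_embeddings_event.set()     # notify the consumer


# Stand-in for SimilarityIndexWorker: waits for the signal, then rebuilds.
@frozen
class Consumer:
    new_embeddings_event: asyncio.Event

    async def run(self):
        await self.new_embeddings_event.wait()
        print("rebuilding similarity index")
        self.new_embeddings_event.clear()


@frozen
class App:
    producer: Producer
    consumer: Consumer

    @staticmethod
    def create() -> "App":
        # The event is created here and injected into both workers,
        # instead of living at module scope.
        event = asyncio.Event()
        return App(Producer(event), Consumer(event))


async def main():
    app = App.create()
    async with asyncio.TaskGroup() as tg:   # Python 3.11+
        tg.create_task(app.consumer.run())
        tg.create_task(app.producer.run())


asyncio.run(main())

This keeps the workers immutable and self-contained while letting the application object decide how they are wired together, which is what the commit does for the real worker set.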