|
|
@@ -298,14 +298,12 @@ def batched(iterable, n):
|
|
|
yield batch
|
|
|
|
|
|
|
|
|
-new_embeddings_event = asyncio.Event()
|
|
|
-
|
|
|
-
|
|
|
@frozen
|
|
|
class EmbeddingsWorker(Worker):
|
|
|
storage: Storage
|
|
|
model_name: str
|
|
|
batch_size: int
|
|
|
+ new_embeddings_event: asyncio.Event
|
|
|
model: Any = field(init=False, default=None)
|
|
|
|
|
|
async def run(self):
|
|
|
@@ -341,7 +339,7 @@ class EmbeddingsWorker(Worker):
|
|
|
logger.debug(f"Stored {len(embeddings)} embeddings")
|
|
|
|
|
|
if embeddings:
|
|
|
- new_embeddings_event.set()
|
|
|
+ self.new_embeddings_event.set()
|
|
|
|
|
|
await asyncio.sleep(5)
|
|
|
|
|
|
@@ -349,9 +347,10 @@ class EmbeddingsWorker(Worker):
|
|
|
@frozen
|
|
|
class SimilarityIndexWorker(Worker):
|
|
|
storage: Storage
|
|
|
+ new_embeddings_event: asyncio.Event
|
|
|
|
|
|
async def run(self):
|
|
|
- await new_embeddings_event.wait()
|
|
|
+ await self.new_embeddings_event.wait()
|
|
|
|
|
|
sim_index = SimilaritySearch.create(self.storage)
|
|
|
logger.info("Starting index..")
|
|
|
@@ -359,7 +358,7 @@ class SimilarityIndexWorker(Worker):
|
|
|
await sim_index.save()
|
|
|
logger.info("Similarity index ready")
|
|
|
|
|
|
- new_embeddings_event.clear()
|
|
|
+ self.new_embeddings_event.clear()
|
|
|
|
|
|
|
|
|
@frozen
|
|
|
@@ -377,41 +376,55 @@ class WebServer(Worker):
|
|
|
return
|
|
|
|
|
|
|
|
|
@frozen
class MediaObserverApplication:
    """Bundle of the long-running components that make up the application.

    Instances are normally obtained through :meth:`create`, which builds
    every component and hands the same ``asyncio.Event`` to both the
    embeddings worker and the similarity index worker.
    """

    # Pipeline workers keyed by stage name; ``create`` fills in
    # "snapshot", "fetch", "parse" and "store".
    workers: dict[str, Worker]
    # Web server component.
    web_server: WebServer
    # Worker built from the storage, a model name and a batch size.
    embeds: EmbeddingsWorker
    # Worker that shares the "new embeddings" event with ``embeds``.
    index: SimilarityIndexWorker

    @staticmethod
    async def create(
        storage: Storage,
        ia: InternetArchiveClient,
        model_name: str = "dangvantuan/sentence-camembert-large",
        batch_size: int = 64,
    ) -> "MediaObserverApplication":
        """Build all workers and return them as one application object.

        Args:
            storage: Storage handle passed to the snapshot, store,
                embeddings and index workers.
            ia: Internet Archive client passed to the snapshot and fetch
                workers.
            model_name: Model name handed to the embeddings worker.
            batch_size: Batch size handed to the embeddings worker.

        Returns:
            A new ``MediaObserverApplication``.

        The method performs no awaiting itself; it stays ``async`` so that
        existing ``await MediaObserverApplication.create(...)`` call sites
        keep working.
        """
        # One event shared by the embeddings and index workers. It starts
        # in the "set" state so the first wait on it returns immediately.
        new_embeddings_event = asyncio.Event()
        new_embeddings_event.set()

        # Each stage reads from one job queue and writes to the next one;
        # the last stage has no outbound queue.
        workers = {
            "snapshot": SnapshotWorker(
                SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
            ),
            "fetch": FetchWorker(SnapshotFetchJob.queue, SnapshotParseJob.queue, ia),
            "parse": ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
            "store": StoreWorker(SnapshotStoreJob.queue, None, storage),
        }
        web_server = WebServer()
        embeds = EmbeddingsWorker(
            storage,
            model_name,
            batch_size,
            new_embeddings_event,
        )
        index = SimilarityIndexWorker(storage, new_embeddings_event)
        return MediaObserverApplication(workers, web_server, embeds, index)
|
|
|
+
|
|
|
+
|
|
|
async def main():
    """Start every component and keep running until the task group exits.

    Creates the snapshot search jobs from ``settings.snapshots``, opens the
    storage and the Internet Archive client, builds the application, then
    schedules each worker's ``loop()`` and the web server's ``run()`` in a
    single ``asyncio.TaskGroup`` before enqueuing the search jobs. The
    storage is closed in all cases once the ``try`` block is left.
    """
    # NOTE(review): block nesting below is reconstructed from a diff view
    # whose indentation was stripped — confirm against the real file.
    tasks = []
    jobs = SnapshotSearchJob.create(
        settings.snapshots.days_in_past, settings.snapshots.hours
    )
    storage = await Storage.create()

    try:
        async with InternetArchiveClient.create() as ia:
            app = await MediaObserverApplication.create(storage, ia)

            async with asyncio.TaskGroup() as tg:
                # Pipeline workers first, then the embeddings and index
                # workers, all driven through their ``loop()`` coroutine.
                looping = [*app.workers.values(), app.embeds, app.index]
                tasks.extend(tg.create_task(worker.loop()) for worker in looping)
                tasks.append(tg.create_task(app.web_server.run()))

                # Feed the first queue after the tasks have been scheduled.
                for job in jobs:
                    await SnapshotSearchJob.queue.put(job)

    finally:
        await storage.close()
|
|
|
|