Przeglądaj źródła

Add a similarity index worker

jherve 1 rok temu
rodzic
commit
9a739012af
1 zmieniony plik z 18 dodaniami i 1 usunięciem
  1. 18 1
      src/media_observer/test.py

+ 18 - 1
src/media_observer/test.py

@@ -23,6 +23,7 @@ from media_observer.internet_archive import (
     InternetArchiveSnapshotId,
     SnapshotNotYetAvailable,
 )
+from media_observer.similarity_index import SimilaritySearch
 from media_observer.storage import Storage
 from media_observer.medias import media_collection
 from media_observer.web import app
@@ -339,6 +340,20 @@ class EmbeddingsWorker(Worker):
 
         await asyncio.sleep(5)
 
+
+@frozen
+class SimilarityIndexWorker(Worker):  # Worker that fills a SimilaritySearch index and saves it.
+    storage: Storage  # passed to SimilaritySearch.create() in run()
+
+    async def run(self):  # one-shot coroutine: no loop or sleep in this method
+        sim_index = SimilaritySearch.create(self.storage)
+
+        logger.info("Starting index..")
+        await sim_index.add_embeddings()  # NOTE(review): presumably reads embeddings via storage -- confirm
+        await sim_index.save()  # NOTE(review): save destination is not visible here -- confirm
+        logger.info("Similarity index ready")
+
+
 @frozen
 class WebServer(Worker):
     async def run(self):
@@ -378,14 +393,16 @@ async def main():
                 "dangvantuan/sentence-camembert-large",
                 64,
             )
+            index = SimilarityIndexWorker(storage)
             async with asyncio.TaskGroup() as tg:
                 for w in workers.values():
                     tasks.append(tg.create_task(w.loop()))
-                for j in jobs[:3]:
+                for j in jobs:
                     await SnapshotSearchJob.queue.put(j)
 
                 tasks.append(tg.create_task(web_server.run()))
                 tasks.append(tg.create_task(embeds.loop()))
+                tasks.append(tg.create_task(index.run()))
     finally:
         await storage.close()