jherve 1 rok temu
rodzic
commit
d20760130b

+ 1 - 19
src/media_observer/__main__.py

@@ -18,29 +18,11 @@ from media_observer.snapshots import (
     StoreWorker,
     SnapshotSearchJob,
 )
-from media_observer.similarity_index import SimilaritySearch
+from media_observer.similarity_index import SimilarityIndexWorker
 from media_observer.storage import Storage
 from media_observer.web import app
 
 
-@frozen
-class SimilarityIndexWorker(Worker):
-    storage: Storage
-    new_embeddings_event: asyncio.Event
-
-    async def run(self):
-        while True:
-            await self.new_embeddings_event.wait()
-
-            sim_index = SimilaritySearch.create(self.storage)
-            logger.info("Starting index..")
-            await sim_index.add_embeddings()
-            await sim_index.save()
-            logger.info("Similarity index ready")
-
-            self.new_embeddings_event.clear()
-
-
 @frozen
 class WebServer(Worker):
     async def run(self):

+ 15 - 10
src/media_observer/similarity_index.py

@@ -2,13 +2,14 @@ import asyncio
 import os
 from datetime import datetime
 import pickle
-from attrs import define
+from attrs import define, frozen
 from typing import Any, Callable, ClassVar
 from loguru import logger
 from annoy import AnnoyIndex
 
 
 from media_observer.storage import Storage
+from media_observer.worker import Worker
 
 
 file_path_index = "./similarity.index"
@@ -116,15 +117,19 @@ class SimilaritySearch:
             return SimilaritySearch(storage, index)
 
 
-async def main():
-    storage = await Storage.create()
-    sim_index = SimilaritySearch.create(storage)
+@frozen
+class SimilarityIndexWorker(Worker):
+    storage: Storage
+    new_embeddings_event: asyncio.Event
 
-    logger.info("Starting index..")
-    await sim_index.add_embeddings()
-    await sim_index.save()
-    logger.info("Similarity index ready")
+    async def run(self):
+        while True:
+            await self.new_embeddings_event.wait()
 
+            sim_index = SimilaritySearch.create(self.storage)
+            logger.info("Starting index..")
+            await sim_index.add_embeddings()
+            await sim_index.save()
+            logger.info("Similarity index ready")
 
-if __name__ == "__main__":
-    asyncio.run(main())
+            self.new_embeddings_event.clear()