|
|
@@ -23,6 +23,7 @@ from media_observer.internet_archive import (
|
|
|
InternetArchiveSnapshotId,
|
|
|
SnapshotNotYetAvailable,
|
|
|
)
|
|
|
+from media_observer.similarity_index import SimilaritySearch
|
|
|
from media_observer.storage import Storage
|
|
|
from media_observer.medias import media_collection
|
|
|
from media_observer.web import app
|
|
|
@@ -339,6 +340,20 @@ class EmbeddingsWorker(Worker):
|
|
|
|
|
|
await asyncio.sleep(5)
|
|
|
|
|
|
+
|
|
|
+@frozen
|
|
|
+class SimilarityIndexWorker(Worker):
|
|
|
+ storage: Storage
|
|
|
+
|
|
|
+ async def run(self):
|
|
|
+ sim_index = SimilaritySearch.create(self.storage)
|
|
|
+
|
|
|
+ logger.info("Starting index..")
|
|
|
+ await sim_index.add_embeddings()
|
|
|
+ await sim_index.save()
|
|
|
+ logger.info("Similarity index ready")
|
|
|
+
|
|
|
+
|
|
|
@frozen
|
|
|
class WebServer(Worker):
|
|
|
async def run(self):
|
|
|
@@ -378,14 +393,16 @@ async def main():
|
|
|
"dangvantuan/sentence-camembert-large",
|
|
|
64,
|
|
|
)
|
|
|
+ index = SimilarityIndexWorker(storage)
|
|
|
async with asyncio.TaskGroup() as tg:
|
|
|
for w in workers.values():
|
|
|
tasks.append(tg.create_task(w.loop()))
|
|
|
- for j in jobs[:3]:
|
|
|
+ for j in jobs:
|
|
|
await SnapshotSearchJob.queue.put(j)
|
|
|
|
|
|
tasks.append(tg.create_task(web_server.run()))
|
|
|
tasks.append(tg.create_task(embeds.loop()))
|
|
|
+ tasks.append(tg.create_task(index.run()))
|
|
|
finally:
|
|
|
await storage.close()
|
|
|
|