ソースを参照

Add an embeddings watchdog

jherve 1 年前
コミット
4b25b17f11
1 ファイル変更、61 行追加、0 行削除
  1. 61 0
      src/media_observer/test.py

+ 61 - 0
src/media_observer/test.py

@@ -1,4 +1,5 @@
 import asyncio
+from itertools import islice
 from datetime import date, datetime, time, timedelta
 import os
 from pathlib import Path
@@ -283,6 +284,60 @@ class StoreWorker(QueueWorker):
         return {"storage": self.storage}
 
 
+def batched(iterable, n):
+    """
+    Yield successive tuples of at most ``n`` items drawn from ``iterable``.
+
+    Every tuple holds exactly ``n`` items, except possibly the final one
+    which holds whatever is left over:
+        `batched('ABCDEFG', 3) --> ABC DEF G`
+
+    Raises ValueError (on first iteration) when ``n`` is smaller than one.
+
+    Same behaviour as the recipe published in the Python 3.11 docs:
+    https://docs.python.org/3.11/library/itertools.html#itertools-recipes
+    """
+    if n < 1:
+        raise ValueError("n must be at least one")
+    source = iter(iterable)
+    while True:
+        chunk = tuple(islice(source, n))
+        if not chunk:
+            # The underlying iterator is exhausted: stop generating.
+            return
+        yield chunk
+
+
+@frozen
+class EmbeddingsWorker(Worker):
+    """
+    Compute and store embeddings for the titles that do not have one yet.
+
+    A call to `run` performs a single pass: it loads the sentence-transformers
+    model, asks the storage for the titles lacking an embedding, encodes them
+    in batches and stores every resulting vector.
+    """
+
+    # Storage used to list the titles to process and to persist embeddings.
+    storage: Storage
+    # Model identifier handed to `SentenceTransformer(...)`.
+    model_name: str
+    # Maximum number of titles passed to a single `model.encode` call.
+    batch_size: int
+
+    async def run(self):
+        """Load the model, then embed and store all pending titles batch by batch."""
+
+        def load_model():
+            """Import sentence_transformers lazily and instantiate the model."""
+            logger.info("Starting stuff")
+            # Imported here so the dependency is only loaded when the worker
+            # actually runs.
+            from sentence_transformers import SentenceTransformer
+
+            model = SentenceTransformer(self.model_name)
+            return model
+
+        def compute_embeddings_for(model, sentences: tuple[tuple[int, str], ...]):
+            """
+            Encode the text of each `(id, text)` pair in `sentences`.
+
+            Returns a dict mapping each title id to its embedding.
+            """
+            logger.debug(f"Computing embeddings for {len(sentences)} sentences")
+            all_texts = [text for _, text in sentences]
+            all_embeddings = model.encode(all_texts)
+
+            return {
+                title_id: e for (title_id, _), e in zip(sentences, all_embeddings)
+            }
+
+        logger.info("Embeddings watchdog")
+
+        loop = asyncio.get_running_loop()
+        # Model loading is synchronous: run it in the default executor so the
+        # event loop keeps serving the other tasks meanwhile.
+        model = await loop.run_in_executor(None, load_model)
+        logger.info("Model loaded")
+        all_titles = [
+            (t["id"], t["text"])
+            for t in await self.storage.list_all_titles_without_embedding()
+        ]
+
+        for batch in batched(all_titles, self.batch_size):
+            # `model.encode` is a synchronous call as well: offload it to the
+            # executor, like the model load, instead of blocking the event
+            # loop for the whole duration of each batch.
+            embeddings = await loop.run_in_executor(
+                None, compute_embeddings_for, model, batch
+            )
+            logger.info(f"embeddings: {embeddings}")
+            for i, embed in embeddings.items():
+                await self.storage.add_embedding(i, embed)
+
+            logger.debug(f"Stored {len(embeddings)} embeddings")
+
+
 @frozen
 class WebServer(Worker):
     async def run(self):
@@ -317,6 +372,11 @@ async def main():
                 "store": StoreWorker(SnapshotStoreJob.queue, None, storage),
             }
             web_server = WebServer()
+            embeds = EmbeddingsWorker(
+                storage,
+                "dangvantuan/sentence-camembert-large",
+                64,
+            )
             async with asyncio.TaskGroup() as tg:
                 for w in workers.values():
                     tasks.append(tg.create_task(w.loop()))
@@ -324,6 +384,7 @@ async def main():
                     await SnapshotSearchJob.queue.put(j)
 
                 tasks.append(tg.create_task(web_server.run()))
+                tasks.append(tg.create_task(embeds.run()))
     finally:
         await storage.close()