jherve 1 рік тому
батько
коміт
e820826df5
2 змінених файлів з 42 додано та 132 видалено
  1. +2 −66
      src/media_observer/__main__.py
  2. +40 −66
      src/media_observer/embeddings.py

+ 2 - 66
src/media_observer/__main__.py

@@ -1,13 +1,11 @@
 import asyncio
-import concurrent.futures
-from itertools import islice
-from typing import Any
 from loguru import logger
-from attrs import define, field, frozen
+from attrs import frozen
 from hypercorn.asyncio import serve
 from hypercorn.config import Config
 
 from media_observer.worker import Worker
+from media_observer.embeddings import EmbeddingsWorker
 from media_observer.internet_archive import InternetArchiveClient
 from media_observer.snapshots import (
     FetchWorker,
@@ -25,68 +23,6 @@ from media_observer.storage import Storage
 from media_observer.web import app
 
 
-def batched(iterable, n):
-    """
-    Batch data into tuples of length n. The last batch may be shorter.
-        `batched('ABCDEFG', 3) --> ABC DEF G`
-
-    Straight from : https://docs.python.org/3.11/library/itertools.html#itertools-recipes
-    """
-    if n < 1:
-        raise ValueError("n must be at least one")
-    it = iter(iterable)
-    while batch := tuple(islice(it, n)):
-        yield batch
-
-
-@define
-class EmbeddingsWorker(Worker):
-    storage: Storage
-    model_name: str
-    batch_size: int
-    new_embeddings_event: asyncio.Event
-    model: Any = field(init=False, default=None)
-
-    async def run(self):
-        def load_model():
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(self.model_name)
-
-        while True:
-            loop = asyncio.get_running_loop()
-            if self.model is None:
-                await loop.run_in_executor(None, load_model)
-
-            all_titles = [
-                (t["id"], t["text"])
-                for t in await self.storage.list_all_titles_without_embedding()
-            ]
-
-            for batch in batched(all_titles, self.batch_size):
-                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
-                    embeddings = await loop.run_in_executor(
-                        pool, self.compute_embeddings_for, self.model, batch
-                    )
-                for i, embed in embeddings.items():
-                    await self.storage.add_embedding(i, embed)
-
-                logger.debug(f"Stored {len(embeddings)} embeddings")
-
-                if embeddings:
-                    self.new_embeddings_event.set()
-
-            await asyncio.sleep(5)
-
-    @staticmethod
-    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str]]):
-        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
-        all_texts = [t[1] for t in sentences]
-        all_embeddings = model.encode(all_texts)
-
-        return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}
-
-
 @frozen
 class SimilarityIndexWorker(Worker):
     storage: Storage

+ 40 - 66
src/media_observer/embeddings.py

@@ -1,12 +1,11 @@
 import asyncio
-from loguru import logger
+import concurrent.futures
 from itertools import islice
-from collections import defaultdict
 from typing import Any
-from attrs import frozen
-from numpy.typing import NDArray
-from sentence_transformers import SentenceTransformer
+from loguru import logger
+from attrs import define, field
 
+from media_observer.worker import Worker
 from media_observer.storage import Storage
 
 
@@ -24,74 +23,49 @@ def batched(iterable, n):
         yield batch
 
 
-@frozen
-class EmbeddingsJob:
-    title_id: int
-    text: NDArray
-
-    @staticmethod
-    async def create(storage: Storage):
-        all_titles = await storage.list_all_titles_without_embedding()
-        return [EmbeddingsJob(t["id"], t["text"]) for t in all_titles]
+@define
+class EmbeddingsWorker(Worker):
+    storage: Storage
+    model_name: str
+    batch_size: int
+    new_embeddings_event: asyncio.Event
+    model: Any = field(init=False, default=None)
 
+    async def run(self):
+        def load_model():
+            from sentence_transformers import SentenceTransformer
 
-@frozen
-class EmbeddingsWorker:
-    storage: Storage
-    model: Any
+            self.model = SentenceTransformer(self.model_name)
 
-    def compute_embeddings_for(self, sentences: dict[int, str]):
-        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
-        inverted_dict = defaultdict(list)
-        for idx, (k, v) in enumerate(list(sentences.items())):
-            inverted_dict[v].append((idx, k))
-        all_texts = list(inverted_dict.keys())
-        all_embeddings = self.model.encode(all_texts)
-
-        embeddings_by_id = {}
-        for e, text in zip(all_embeddings, all_texts):
-            all_ids = [id for (_, id) in inverted_dict[text]]
-            for i in all_ids:
-                embeddings_by_id[i] = e
-
-        return embeddings_by_id
-
-    async def store_embeddings(self, embeddings_by_id: dict):
-        logger.debug(f"Storing {len(embeddings_by_id)} embeddings")
-        for i, embed in embeddings_by_id.items():
-            await self.storage.add_embedding(i, embed)
-
-    async def run(self, jobs: list[EmbeddingsJob]):
-        batch_size = 64
-        for batch in batched(jobs, batch_size):
-            embeddings_by_id = self.compute_embeddings_for(
-                {j.title_id: j.text for j in batch}
-            )
-            await self.store_embeddings(embeddings_by_id)
+        while True:
+            loop = asyncio.get_running_loop()
+            if self.model is None:
+                await loop.run_in_executor(None, load_model)
 
-    @staticmethod
-    def create(storage, model_path):
-        model = SentenceTransformer(model_path)
-        return EmbeddingsWorker(storage, model)
+            all_titles = [
+                (t["id"], t["text"])
+                for t in await self.storage.list_all_titles_without_embedding()
+            ]
 
+            for batch in batched(all_titles, self.batch_size):
+                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
+                    embeddings = await loop.run_in_executor(
+                        pool, self.compute_embeddings_for, self.model, batch
+                    )
+                for i, embed in embeddings.items():
+                    await self.storage.add_embedding(i, embed)
 
-async def main():
-    storage = await Storage.create()
+                logger.debug(f"Stored {len(embeddings)} embeddings")
 
-    logger.info("Starting embeddings service..")
-    jobs = await EmbeddingsJob.create(storage)
-    if jobs:
-        loop = asyncio.get_event_loop()
-        worker = await loop.run_in_executor(
-            None,
-            EmbeddingsWorker.create,
-            storage,
-            "dangvantuan/sentence-camembert-large",
-        )
-        await worker.run(jobs)
+                if embeddings:
+                    self.new_embeddings_event.set()
 
-    logger.info("Embeddings service exiting")
+            await asyncio.sleep(5)
 
+    @staticmethod
+    def compute_embeddings_for(model: Any, sentences: tuple[tuple[int, str]]):
+        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
+        all_texts = [t[1] for t in sentences]
+        all_embeddings = model.encode(all_texts)
 
-if __name__ == "__main__":
-    asyncio.run(main())
+        return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}