Przeglądaj źródła

Implement embeddings computation & storage

jherve 1 rok temu
rodzic
commit
2c8891a01f

+ 47 - 0
src/de_quoi_parle_le_monde/snapshot_worker.py

@@ -13,6 +13,23 @@ from de_quoi_parle_le_monde.internet_archive import (
 from de_quoi_parle_le_monde.medias import media_collection
 from de_quoi_parle_le_monde.storage import Storage
 
+from itertools import islice
+from collections import defaultdict
+
+
+def batched(iterable, n):
+    """
+    Batch data into tuples of length n. The last batch may be shorter.
+        `batched('ABCDEFG', 3) --> ABC DEF G`
+
+    Straight from : https://docs.python.org/3.11/library/itertools.html#itertools-recipes
+    """
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
+
 
 @frozen
 class SnapshotWorker:
@@ -114,6 +131,27 @@ class EmbeddingsWorker:
     storage: Storage
     model: SentenceTransformer
 
+    def compute_embeddings_for(self, sentences: dict[int, str]):
+        logger.debug(f"Computing embeddings for {len(sentences)} sentences")
+        inverted_dict = defaultdict(list)
+        for idx, (k, v) in enumerate(list(sentences.items())):
+            inverted_dict[v].append((idx, k))
+        all_texts = list(inverted_dict.keys())
+        all_embeddings = self.model.encode(all_texts)
+
+        embeddings_by_id = {}
+        for e, text in zip(all_embeddings, all_texts):
+            all_ids = [id for (_, id) in inverted_dict[text]]
+            for i in all_ids:
+                embeddings_by_id[i] = e
+
+        return embeddings_by_id
+
+    async def store_embeddings(self, embeddings_by_id: dict):
+        logger.debug(f"Storing {len(embeddings_by_id)} embeddings")
+        for i, embed in embeddings_by_id.items():
+            await self.storage.add_embedding(i, embed)
+
     @staticmethod
     def create(storage, model_path):
         model = SentenceTransformer(model_path)
@@ -122,4 +160,13 @@ class EmbeddingsWorker:
 
 async def compute_embeddings(storage: Storage):
     worker = EmbeddingsWorker.create(storage, "dangvantuan/sentence-camembert-large")
+    all_snapshots = await storage.list_all_featured_article_snapshots()
+
+    batch_size = 64
+    for batch in batched(all_snapshots, batch_size):
+        embeddings_by_id = worker.compute_embeddings_for(
+            {s["id"]: s["title"] for s in batch}
+        )
+        await worker.store_embeddings(embeddings_by_id)
+
     return worker

+ 26 - 0
src/de_quoi_parle_le_monde/storage.py

@@ -2,6 +2,7 @@ from typing import Any
 import aiosqlite
 import asyncio
 from datetime import datetime
+import numpy as np
 
 from de_quoi_parle_le_monde.article import (
     TopArticle,
@@ -290,6 +291,31 @@ class Storage:
             )
             await conn.commit()
 
+    async def list_all_featured_article_snapshots(self):
+        async with self.conn as conn:
+            rows = await conn.execute_fetchall(
+                f"""
+                    SELECT *
+                    FROM featured_article_snapshots
+                """,
+            )
+
+            return [
+                {"id": r[0], "featured_article_id": r[1], "title": r[2], "url": r[3]}
+                for r in rows
+            ]
+
+    async def add_embedding(self, featured_article_snapshot_id: int, embedding):
+        async with self.conn as conn:
+            await conn.execute_insert(
+                self._insert_stmt(
+                    "articles_embeddings",
+                    ["featured_article_snapshot_id", "title_embedding"],
+                ),
+                [featured_article_snapshot_id, embedding],
+            )
+            await conn.commit()
+
     async def select_from(self, table):
         async with self.conn as conn:
             return await conn.execute_fetchall(