Przeglądaj źródła

Add SimilaritySearch with faiss

jherve 1 rok temu
rodzic
commit
6ba9348f31

+ 10 - 1
src/de_quoi_parle_le_monde/main.py

@@ -14,6 +14,7 @@ from de_quoi_parle_le_monde.workers.embeddings import (
     EmbeddingsJob,
     EmbeddingsWorker,
 )
+from de_quoi_parle_le_monde.similarity_search import SimilaritySearch
 
 
 @frozen
@@ -22,11 +23,13 @@ class Application:
     storage: Storage
     web_app: Framework
     web_config: Config
+    similarity_index: SimilaritySearch
 
     async def run(self):
         await asyncio.gather(
             self._run_web_server(),
             self._run_snapshot_worker(),
+            self._run_similarity_index(),
             self._run_embeddings_worker(),
         )
         logger.info("Will quit now..")
@@ -55,14 +58,20 @@ class Application:
         )
         await worker.run(jobs)
 
+    async def _run_similarity_index(self):
+        """Populate the similarity index, logging before and after."""
+        logger.info("Starting index..")
+        index = self.similarity_index
+        await index.add_embeddings()
+        logger.info("Similarity index ready")
+
     @staticmethod
     async def create():
+        """Build an Application wired with its HTTP client, storage, web app,
+        web config and similarity index."""
         http_client = HttpClient()
         storage = await Storage.create()
         web_app = app
         web_config = Config()
+        # SimilaritySearch.create hands back one shared instance per process,
+        # creating it on first use.
+        sim_index = SimilaritySearch.create(storage)
 
-        return Application(http_client, storage, web_app, web_config)
+        return Application(http_client, storage, web_app, web_config, sim_index)
 
 
 async def main():

+ 46 - 0
src/de_quoi_parle_le_monde/similarity_search.py

@@ -0,0 +1,46 @@
+from typing import Callable
+import faiss
+import numpy as np
+
+
+class SimilaritySearch:
+    """Process-wide faiss index over article title embeddings.
+
+    Vectors are L2-normalised before being added or searched and the index
+    uses the inner-product metric, so the scores produced by ``search`` are
+    cosine similarities. Vectors are keyed in the index by their
+    ``featured_article_snapshot_id``.
+    """
+
+    # Shared instance handed out by ``create``.
+    instance = None
+    # Dimension of the title embeddings held by the index.
+    EMBEDDING_DIMENSION = 1024
+
+    def __init__(self, storage) -> None:
+        self.storage = storage
+        self.index = faiss.index_factory(
+            self.EMBEDDING_DIMENSION, "IDMap,Flat", faiss.METRIC_INNER_PRODUCT
+        )
+
+    async def add_embeddings(self):
+        """Load every embedding returned by storage into the index.
+
+        Does nothing when storage holds no embeddings: an empty list cannot
+        be turned into the 2-D float32 matrix faiss expects.
+        """
+        embeds = await self.storage.list_all_articles_embeddings()
+        if not embeds:
+            return
+
+        all_titles = np.array([e["title_embedding"] for e in embeds], dtype="float32")
+        ids = np.array(
+            [e["featured_article_snapshot_id"] for e in embeds], dtype="int64"
+        )
+        # normalize_L2 works in place; ``all_titles`` is a fresh array.
+        faiss.normalize_L2(all_titles)
+        self.index.add_with_ids(all_titles, ids)
+
+    async def search(
+        self,
+        featured_article_snapshot_ids: list[int],
+        nb_results: int,
+        score_func: Callable[[float], bool],
+    ):
+        """Find the nearest indexed titles for each requested article.
+
+        Args:
+            featured_article_snapshot_ids: ids whose stored title embedding is
+                used as a query vector.
+            nb_results: number of neighbours requested from faiss per query.
+            score_func: predicate on the similarity score; only neighbours for
+                which it returns a true value are kept.
+
+        Returns:
+            A list of ``(query_id, [(neighbour_id, score), ...])`` tuples, in
+            the order of ``featured_article_snapshot_ids``. Ids for which
+            storage returned no embedding are left out.
+        """
+        embeds = await self.storage.get_article_embedding(featured_article_snapshot_ids)
+        # Storage gives no ordering guarantee and may omit ids that have no
+        # embedding, so pair vectors with ids by key rather than by position.
+        embedding_by_id = {
+            e["featured_article_snapshot_id"]: e["title_embedding"] for e in embeds
+        }
+        query_ids = [
+            id_ for id_ in featured_article_snapshot_ids if id_ in embedding_by_id
+        ]
+        if not query_ids:
+            return []
+
+        queries = np.array([embedding_by_id[id_] for id_ in query_ids], dtype="float32")
+        faiss.normalize_L2(queries)
+        scores, neighbours = self.index.search(queries, nb_results)
+
+        return [
+            (
+                query_id,
+                [
+                    (int(neighbour), float(score))
+                    for score, neighbour in zip(row_scores, row_neighbours)
+                    # faiss pads short result lists with the label -1.
+                    if neighbour != -1 and score_func(float(score))
+                ],
+            )
+            for query_id, row_scores, row_neighbours in zip(
+                query_ids, scores, neighbours
+            )
+        ]
+
+    @classmethod
+    def create(cls, storage):
+        """Return the shared instance, building it from ``storage`` on first call.
+
+        Later calls return the existing instance and ignore ``storage``.
+        """
+        if cls.instance is None:
+            cls.instance = cls(storage)
+
+        return cls.instance

+ 39 - 0
src/de_quoi_parle_le_monde/storage.py

@@ -316,6 +316,45 @@ class Storage:
 
             return [r[0] for r in rows]
 
+    async def list_all_articles_embeddings(self):
+        """Return every row of ``articles_embeddings`` as a dict.
+
+        Each dict has the keys ``id``, ``featured_article_snapshot_id`` and
+        ``title_embedding``; the latter is the third column decoded as a
+        float32 numpy array.
+        """
+        async with self.conn as conn:
+            # Plain string literal: the query has nothing to interpolate.
+            rows = await conn.execute_fetchall(
+                """
+                    SELECT *
+                    FROM articles_embeddings
+                """,
+            )
+
+            # NOTE(review): relies on SELECT * yielding the columns in the
+            # order (id, featured_article_snapshot_id, title_embedding);
+            # confirm against the table definition.
+            return [
+                {
+                    "id": r[0],
+                    "featured_article_snapshot_id": r[1],
+                    "title_embedding": np.frombuffer(r[2], dtype="float32"),
+                }
+                for r in rows
+            ]
+
+    async def get_article_embedding(self, featured_article_snapshot_ids: list[int]):
+        """Fetch the ``articles_embeddings`` rows for the given snapshot ids.
+
+        Returns a list of dicts with the keys ``id``,
+        ``featured_article_snapshot_id`` and ``title_embedding`` (the third
+        column decoded as a float32 numpy array). The query imposes no
+        ordering on the rows.
+        """
+        async with self.conn as conn:
+            # One "?" per requested id; the ids themselves are bound as
+            # parameters, never interpolated into the SQL text.
+            placeholders = ", ".join("?" for _ in featured_article_snapshot_ids)
+            query = f"""
+                    SELECT *
+                    FROM articles_embeddings
+                    WHERE featured_article_snapshot_id IN ({placeholders})
+                """
+            rows = await conn.execute_fetchall(query, featured_article_snapshot_ids)
+
+            embeddings = []
+            for row in rows:
+                embeddings.append(
+                    {
+                        "id": row[0],
+                        "featured_article_snapshot_id": row[1],
+                        "title_embedding": np.frombuffer(row[2], dtype="float32"),
+                    }
+                )
+            return embeddings
+
     async def add_embedding(self, featured_article_snapshot_id: int, embedding):
         async with self.conn as conn:
             await conn.execute_insert(

+ 6 - 0
src/de_quoi_parle_le_monde/web.py

@@ -5,6 +5,7 @@ from fastapi.templating import Jinja2Templates
 
 from de_quoi_parle_le_monde.medias import media_collection
 from de_quoi_parle_le_monde.storage import Storage
+from de_quoi_parle_le_monde.similarity_search import SimilaritySearch
 
 
 app = FastAPI()
@@ -16,6 +17,10 @@ async def get_db():
     return await Storage.create()
 
 
+async def get_similarity_search(storage: Storage = Depends(get_db)):
+    """FastAPI dependency returning the shared SimilaritySearch instance."""
+    similarity_search = SimilaritySearch.create(storage)
+    return similarity_search
+
+
 @app.get("/", response_class=HTMLResponse)
 async def index(request: Request, storage: Storage = Depends(get_db)):
     sites = await storage.list_sites()
@@ -30,6 +35,7 @@ async def site_main_article_snapshot(
     id: int,
     snapshot_id: int,
     storage: Storage = Depends(get_db),
+    sim_index: SimilaritySearch = Depends(get_similarity_search),
 ):
     def get_article_sibling(after_before_articles, cond_fun):
         return min(