Explorar o código

Switch to Annoy for vector search/indexing

jherve hai 1 ano
pai
achega
d51ff03bb2

+ 3 - 0
.gitignore

@@ -163,3 +163,6 @@ cython_debug/
 
 # Ignore dynaconf secret files
 .secrets.*
+
+similarity.index
+similarity.class

+ 2 - 1
pyproject.toml

@@ -21,12 +21,13 @@ dependencies = [
     "hypercorn>=0.16.0",
     "fastapi>=0.110.1",
     "jinja2>=3.1.3",
-    "faiss-cpu>=1.8.0",
     "sentencepiece>=0.2.0",
     "protobuf>=5.26.1",
     "dynaconf>=3.2.5",
     "packaging>=24.0",
     "asyncpg>=0.29.0",
+    "annoy>=1.17.3",
+    "numpy>=1.26.4",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"

+ 3 - 3
requirements-dev.lock

@@ -31,6 +31,8 @@ aiosqlite==0.20.0
     # via aiohttp-client-cache
 annotated-types==0.6.0
     # via pydantic
+annoy==1.17.3
+    # via de-quoi-parle-le-monde
 anyio==4.3.0
     # via httpx
     # via starlette
@@ -72,8 +74,6 @@ dynaconf==3.2.5
     # via de-quoi-parle-le-monde
 email-validator==2.1.1
     # via fastapi
-faiss-cpu==1.8.0
-    # via de-quoi-parle-le-monde
 fastapi==0.111.0
     # via de-quoi-parle-le-monde
     # via fastapi-cli
@@ -131,7 +131,7 @@ multidict==6.0.5
     # via aiohttp
     # via yarl
 numpy==1.26.4
-    # via faiss-cpu
+    # via de-quoi-parle-le-monde
 orjson==3.10.3
     # via fastapi
 packaging==24.0

+ 3 - 3
requirements.lock

@@ -31,6 +31,8 @@ aiosqlite==0.20.0
     # via aiohttp-client-cache
 annotated-types==0.6.0
     # via pydantic
+annoy==1.17.3
+    # via de-quoi-parle-le-monde
 anyio==4.3.0
     # via httpx
     # via starlette
@@ -72,8 +74,6 @@ dynaconf==3.2.5
     # via de-quoi-parle-le-monde
 email-validator==2.1.1
     # via fastapi
-faiss-cpu==1.8.0
-    # via de-quoi-parle-le-monde
 fastapi==0.111.0
     # via de-quoi-parle-le-monde
     # via fastapi-cli
@@ -131,7 +131,7 @@ multidict==6.0.5
     # via aiohttp
     # via yarl
 numpy==1.26.4
-    # via faiss-cpu
+    # via de-quoi-parle-le-monde
 orjson==3.10.3
     # via fastapi
 packaging==24.0

+ 61 - 34
src/de_quoi_parle_le_monde/similarity_index.py

@@ -1,20 +1,26 @@
 import asyncio
-from typing import Callable
+import pickle
+from attrs import define
+from typing import Any, Callable, ClassVar
 from loguru import logger
-import faiss
-import numpy as np
+from annoy import AnnoyIndex
 
 
 from de_quoi_parle_le_monde.storage import Storage
 
 
-class SimilaritySearch:
-    instance = None
+file_path_index = "./similarity.index"
+file_path_pickle_class = "./similarity.class"
+
 
-    def __init__(self, storage) -> None:
-        d = 1024
-        self.storage = storage
-        self.index = faiss.index_factory(d, "IDMap,Flat", faiss.METRIC_INNER_PRODUCT)
+
+@define
+class SimilaritySearch:
+    storage: Storage
+    index: AnnoyIndex
+    embedding_to_featured: dict[int, int] = {}
+    featured_to_embedding: dict[int, int] = {}
+    instance: ClassVar[Any | None] = None
 
     async def add_embeddings(self):
         embeds = await self.storage.list_all_articles_embeddings()
@@ -26,11 +32,12 @@ class SimilaritySearch:
             logger.error(msg)
             raise ValueError(msg)
 
-        all_titles = np.array([e["title_embedding"] for e in embeds])
-        faiss.normalize_L2(all_titles)
-        self.index.add_with_ids(
-            all_titles, [e["featured_article_snapshot_id"] for e in embeds]
-        )
+        for e in embeds:
+            self.index.add_item(e["id"], e["title_embedding"])
+            self.embedding_to_featured[e["id"]] = e["featured_article_snapshot_id"]
+            self.featured_to_embedding[e["featured_article_snapshot_id"]] = e["id"]
+
+        self.index.build(20)
 
     async def search(
         self,
@@ -38,34 +45,56 @@ class SimilaritySearch:
         nb_results: int,
         score_func: Callable[[float], bool],
     ):
-        embeds = await self.storage.get_article_embedding(featured_article_snapshot_ids)
-
-        if (nb_embeds := len(embeds)) != (
-            nb_articles := len(featured_article_snapshot_ids)
-        ):
+        try:
+            [embed_id] = [
+                self.featured_to_embedding[id_] for id_ in featured_article_snapshot_ids
+            ]
+        except KeyError as e:
             msg = (
-                f"Expected {nb_articles} embedding(s) in storage but found only {nb_embeds}. "
+                f"Could not find all embedding(s) in storage for {featured_article_snapshot_ids}. "
                 "A plausible cause is that they have not been computed yet"
             )
             logger.error(msg)
-            raise ValueError(msg)
-
-        all_titles = np.array([e["title_embedding"] for e in embeds])
-        faiss.normalize_L2(all_titles)
-        scores, indices = self.index.search(np.array(all_titles), nb_results)
+            raise e
 
+        indices, distances = self.index.get_nns_by_item(
+            embed_id, nb_results, include_distances=True
+        )
         return [
             (
-                featured_article_snapshot_ids[idx],
-                [(int(i), d) for d, i in res if score_func(d)],
+                embed_id,
+                [
+                    (self.embedding_to_featured[i], d)
+                    for i, d in (zip(indices, distances))
+                    if i != embed_id and score_func(d)
+                ],
             )
-            for idx, res in enumerate(np.dstack((scores, indices)))
         ]
 
     @classmethod
     def create(cls, storage):
         if cls.instance is None:
-            cls.instance = SimilaritySearch(storage)
+            d = 1024
+            index = AnnoyIndex(d, "dot")
+            cls.instance = SimilaritySearch(storage, index)
+
+        return cls.instance
+
+    async def save(self):
+        self.index.save(file_path_index)
+        with open(file_path_pickle_class, "wb") as f:
+            pickle.dump((self.embedding_to_featured, self.featured_to_embedding), f)
+
+    @classmethod
+    def load(cls, storage):
+        if cls.instance is None:
+            d = 1024
+            index = AnnoyIndex(d, "dot")
+            index.load(file_path_index)
+            with open(file_path_pickle_class, "rb") as f:
+                (embedding_to_featured, featured_to_embedding) = pickle.load(f)
+
+            cls.instance = SimilaritySearch(storage, index, embedding_to_featured, featured_to_embedding)
 
         return cls.instance
 
@@ -75,11 +104,9 @@ async def main():
     sim_index = SimilaritySearch.create(storage)
 
     logger.info("Starting index..")
-    try:
-        await sim_index.add_embeddings()
-        logger.info("Similarity index ready")
-    except ValueError:
-        ...
+    await sim_index.add_embeddings()
+    await sim_index.save()
+    logger.info("Similarity index ready")
 
 
 if __name__ == "__main__":

+ 3 - 3
src/de_quoi_parle_le_monde/web.py

@@ -18,7 +18,7 @@ async def get_db():
 
 
 async def get_similarity_search(storage: Storage = Depends(get_db)):
-    return SimilaritySearch.create(storage)
+    return SimilaritySearch.load(storage)
 
 
 @app.get("/", response_class=HTMLResponse)
@@ -62,9 +62,9 @@ async def site_main_article_snapshot(
         [(_, similar)] = await sim_index.search(
             [focused_article_id],
             20,
-            lambda s: s < 1.0 and s >= 0.5,
+            lambda s: s < 100 and s >= 25,
         )
-    except ValueError:
+    except KeyError as e:
         similar = []
 
     similar_by_id = {s[0]: s[1] for s in similar}