@@ -1,20 +1,26 @@
 import asyncio
-from typing import Callable
+import pickle
+from attrs import define
+from typing import Any, Callable, ClassVar
 from loguru import logger
-import faiss
-import numpy as np
+from annoy import AnnoyIndex
 
 
 from de_quoi_parle_le_monde.storage import Storage
 
 
-class SimilaritySearch:
-    instance = None
+file_path_index = "./similarity.index"
+file_path_pickle_class = "./similarity.class"
+
 
-    def __init__(self, storage) -> None:
-        d = 1024
-        self.storage = storage
-        self.index = faiss.index_factory(d, "IDMap,Flat", faiss.METRIC_INNER_PRODUCT)
+
+@define
+class SimilaritySearch:
+    storage: Storage
+    index: AnnoyIndex
+    embedding_to_featured: dict[int, int] = {}
+    featured_to_embedding: dict[int, int] = {}
+    instance: ClassVar[Any | None] = None
 
     async def add_embeddings(self):
         embeds = await self.storage.list_all_articles_embeddings()
@@ -26,11 +32,12 @@ class SimilaritySearch:
             logger.error(msg)
             raise ValueError(msg)
 
-        all_titles = np.array([e["title_embedding"] for e in embeds])
-        faiss.normalize_L2(all_titles)
-        self.index.add_with_ids(
-            all_titles, [e["featured_article_snapshot_id"] for e in embeds]
-        )
+        for e in embeds:
+            self.index.add_item(e["id"], e["title_embedding"])
+            self.embedding_to_featured[e["id"]] = e["featured_article_snapshot_id"]
+            self.featured_to_embedding[e["featured_article_snapshot_id"]] = e["id"]
+
+        self.index.build(20)
 
     async def search(
         self,
@@ -38,34 +45,56 @@ class SimilaritySearch:
         nb_results: int,
         score_func: Callable[[float], bool],
     ):
-        embeds = await self.storage.get_article_embedding(featured_article_snapshot_ids)
-
-        if (nb_embeds := len(embeds)) != (
-            nb_articles := len(featured_article_snapshot_ids)
-        ):
+        try:
+            [embed_id] = [
+                self.featured_to_embedding[id_] for id_ in featured_article_snapshot_ids
+            ]
+        except KeyError as e:
             msg = (
-                f"Expected {nb_articles} embedding(s) in storage but found only {nb_embeds}. "
+                f"Could not find all embedding(s) in storage for {featured_article_snapshot_ids}. "
                 "A plausible cause is that they have not been computed yet"
             )
             logger.error(msg)
-            raise ValueError(msg)
-
-        all_titles = np.array([e["title_embedding"] for e in embeds])
-        faiss.normalize_L2(all_titles)
-        scores, indices = self.index.search(np.array(all_titles), nb_results)
+            raise e
 
+        indices, distances = self.index.get_nns_by_item(
+            embed_id, nb_results, include_distances=True
+        )
         return [
             (
-                featured_article_snapshot_ids[idx],
-                [(int(i), d) for d, i in res if score_func(d)],
+                embed_id,
+                [
+                    (self.embedding_to_featured[i], d)
+                    for i, d in zip(indices, distances)
+                    if i != embed_id and score_func(d)
+                ],
             )
-            for idx, res in enumerate(np.dstack((scores, indices)))
         ]
 
     @classmethod
     def create(cls, storage):
         if cls.instance is None:
-            cls.instance = SimilaritySearch(storage)
+            d = 1024
+            index = AnnoyIndex(d, "dot")
+            cls.instance = SimilaritySearch(storage, index)
+
+        return cls.instance
+
+    async def save(self):
+        self.index.save(file_path_index)
+        with open(file_path_pickle_class, "wb") as f:
+            pickle.dump((self.embedding_to_featured, self.featured_to_embedding), f)
+
+    @classmethod
+    def load(cls, storage):
+        if cls.instance is None:
+            d = 1024
+            index = AnnoyIndex(d, "dot")
+            index.load(file_path_index)
+            with open(file_path_pickle_class, "rb") as f:
+                (embedding_to_featured, featured_to_embedding) = pickle.load(f)
+
+            cls.instance = SimilaritySearch(storage, index, embedding_to_featured, featured_to_embedding)
 
         return cls.instance
 
@@ -75,11 +104,9 @@ async def main():
     sim_index = SimilaritySearch.create(storage)
 
     logger.info("Starting index..")
-    try:
-        await sim_index.add_embeddings()
-        logger.info("Similarity index ready")
-    except ValueError:
-        ...
+    await sim_index.add_embeddings()
+    await sim_index.save()
+    logger.info("Similarity index ready")
 
 
 if __name__ == "__main__":