Sfoglia il codice sorgente

Some formating/linting

jherve 1 anno fa
parent
commit
4a4ea9bdc4

+ 3 - 3
src/de_quoi_parle_le_monde/similarity_search.py

@@ -16,7 +16,7 @@ class SimilaritySearch:
         embeds = await self.storage.list_all_articles_embeddings()
         if not embeds:
             msg = (
-                f"Did not find any embeddings in storage. "
+                "Did not find any embeddings in storage. "
                 "A plausible cause is that they have not been computed yet"
             )
             logger.error(msg)
@@ -48,14 +48,14 @@ class SimilaritySearch:
 
         all_titles = np.array([e["title_embedding"] for e in embeds])
         faiss.normalize_L2(all_titles)
-        D, I = self.index.search(np.array(all_titles), nb_results)
+        scores, indices = self.index.search(np.array(all_titles), nb_results)
 
         return [
             (
                 featured_article_snapshot_ids[idx],
                 [(int(i), d) for d, i in res if score_func(d)],
             )
-            for idx, res in enumerate(np.dstack((D, I)))
+            for idx, res in enumerate(np.dstack((scores, indices)))
         ]
 
     @classmethod

+ 13 - 6
src/de_quoi_parle_le_monde/storage.py

@@ -63,7 +63,7 @@ class Storage:
             "url_article",
             "is_main",
             "rank",
-        ]
+        ],
     }
 
     def __init__(self):
@@ -267,7 +267,14 @@ class Storage:
     ) -> int:
         return await self._insert_or_get(
             self._insert_stmt(
-                "snapshots", ["timestamp", "site_id", "timestamp_virtual", "url_original", "url_snapshot"]
+                "snapshots",
+                [
+                    "timestamp",
+                    "site_id",
+                    "timestamp_virtual",
+                    "url_original",
+                    "url_snapshot",
+                ],
             ),
             [snapshot.timestamp, site_id, virtual, snapshot.original, snapshot.url],
             """
@@ -333,7 +340,7 @@ class Storage:
     async def exists_snapshot(self, name: str, dt: datetime):
         async with self.conn as conn:
             exists = await conn.execute_fetchall(
-                f"""
+                """
                     SELECT 1
                     FROM snapshots snap
                     JOIN sites s ON s.id = snap.site_id
@@ -347,7 +354,7 @@ class Storage:
     async def list_all_featured_article_snapshots(self):
         async with self.conn as conn:
             rows = await conn.execute_fetchall(
-                f"""
+                """
                     SELECT *
                     FROM featured_article_snapshots
                 """,
@@ -380,7 +387,7 @@ class Storage:
     async def list_all_embedded_featured_article_snapshot_ids(self) -> list[int]:
         async with self.conn as conn:
             rows = await conn.execute_fetchall(
-                f"""
+                """
                     SELECT featured_article_snapshot_id
                     FROM articles_embeddings
                 """,
@@ -391,7 +398,7 @@ class Storage:
     async def list_all_articles_embeddings(self):
         async with self.conn as conn:
             rows = await conn.execute_fetchall(
-                f"""
+                """
                     SELECT *
                     FROM articles_embeddings
                 """,

+ 1 - 1
src/de_quoi_parle_le_monde/workers/snapshot.py

@@ -108,7 +108,7 @@ class SnapshotWorker:
             main_page = await self.parse(collection, closest)
             await self.store(main_page, collection, dt)
             logger.info(f"Snap for collection {collection.name} @ {dt} is stored")
-        except Exception as e:
+        except Exception:
             return
 
     @staticmethod