Browse Source

Add column names info

jherve 1 year ago
parent
commit
2858f65987
1 changed files with 60 additions and 29 deletions
  1. 60 29
      src/de_quoi_parle_le_monde/storage.py

+ 60 - 29
src/de_quoi_parle_le_monde/storage.py

@@ -70,6 +70,10 @@ class Table:
     columns: list[Column]
     indexes: list[UniqueIndex]
 
+    @property
+    def column_names(self):
+        return [c.name for c in self.columns]
+
     async def create_if_not_exists(self, conn):
         cols = ",\n".join([f"{c.name} {c.attrs}" for c in self.columns])
         await conn.execute(f"""
@@ -86,6 +90,7 @@ class Table:
 class View:
     name: str
     create_stmt: str
+    column_names: list[str]
 
     async def create_if_not_exists(self, conn):
         stmt = f"""
@@ -96,29 +101,6 @@ class View:
 
 
 class Storage:
-    columns = {
-        "featured_article_snapshots": ["id", "featured_article_id", "title", "url"],
-        "articles_embeddings": [
-            "id",
-            "featured_article_snapshot_id",
-            "title_embedding",
-        ],
-        "snapshot_apparitions": [
-            "snapshot_id",
-            "site_id",
-            "site_name",
-            "site_original_url",
-            "timestamp",
-            "timestamp_virtual",
-            "featured_article_snapshot_id",
-            "featured_article_id",
-            "title",
-            "url_archive",
-            "url_article",
-            "is_main",
-            "rank",
-        ],
-    }
     tables = [
         Table(
             name="sites",
@@ -254,6 +236,14 @@ class Storage:
     views = [
         View(
             name="snapshots_view",
+            column_names=[
+                "id",
+                "site_id",
+                "site_name",
+                "site_original_url",
+                "timestamp",
+                "timestamp_virtual",
+            ],
             create_stmt="""
                 SELECT
                     s.id,
@@ -270,6 +260,16 @@ class Storage:
         ),
         View(
             name="main_page_apparitions",
+            column_names=[
+                "id",
+                "featured_article_id",
+                "title",
+                "url_archive",
+                "url_article",
+                "main_in_snapshot_id",
+                "top_in_snapshot_id",
+                "rank",
+            ],
             create_stmt="""
                 SELECT
                     fas.id,
@@ -288,6 +288,21 @@ class Storage:
         ),
         View(
             name="snapshot_apparitions",
+            column_names=[
+                "snapshot_id",
+                "site_id",
+                "site_name",
+                "site_original_url",
+                "timestamp",
+                "timestamp_virtual",
+                "featured_article_snapshot_id",
+                "featured_article_id",
+                "title",
+                "url_archive",
+                "url_article",
+                "is_main",
+                "rank",
+            ],
             create_stmt="""
                 SELECT
                     sv.id as snapshot_id,
@@ -436,7 +451,10 @@ class Storage:
                 """,
             )
 
-            return [self._from_row(r, "featured_article_snapshots") for r in rows]
+            return [
+                self._from_row(r, self._table_by_name["featured_article_snapshots"])
+                for r in rows
+            ]
 
     async def list_snapshot_apparitions(self, featured_article_snapshot_ids: list[int]):
         async with self.conn as conn:
@@ -450,11 +468,14 @@ class Storage:
                 featured_article_snapshot_ids,
             )
 
-            return [self._from_row(r, "snapshot_apparitions") for r in rows]
+            return [
+                self._from_row(r, self._view_by_name["snapshot_apparitions"])
+                for r in rows
+            ]
 
     @classmethod
-    def _from_row(cls, r, table_or_view: str):
-        columns = cls.columns[table_or_view]
+    def _from_row(cls, r, table_or_view: Table | View):
+        columns = table_or_view.column_names
 
         return {col: r[idx] for idx, col in enumerate(columns)}
 
@@ -496,7 +517,8 @@ class Storage:
 
     @classmethod
     def _from_articles_embeddings_row(cls, r):
-        d = cls._from_row(r, "articles_embeddings")
+        [embeds_view] = [v for v in cls.views if v.name == "articles_embeddings"]
+        d = cls._from_row(r, embeds_view)
         d.update(title_embedding=np.frombuffer(d["title_embedding"], dtype="float32"))
 
         return d
@@ -579,7 +601,8 @@ class Storage:
             )
 
             return [
-                self._from_row(a, "snapshot_apparitions") | {"time_diff": a[13]}
+                self._from_row(a, self._view_by_name["snapshot_apparitions"])
+                | {"time_diff": a[13]}
                 for a in main_articles
             ]
 
@@ -617,3 +640,11 @@ class Storage:
             VALUES ({placeholders})
             ON CONFLICT DO NOTHING;
         """
+
+    @property
+    def _table_by_name(self):
+        return {t.name: t for t in self.tables}
+
+    @property
+    def _view_by_name(self):
+        return {v.name: v for v in self.views}