Преглед на файлове

Update all queries to match postgres' dialect

jherve преди 1 година
родител
ревизия
78907db424
променени са 1 файла, в които са добавени 30 реда и са изтрити 24 реда
  1. 30 24
      src/de_quoi_parle_le_monde/storage.py

+ 30 - 24
src/de_quoi_parle_le_monde/storage.py

@@ -334,9 +334,10 @@ class Storage:
                     SELECT 1
                     FROM snapshots snap
                     JOIN sites s ON s.id = snap.site_id
-                    WHERE s.name = ? AND timestamp_virtual = ?
+                    WHERE s.name = $1 AND timestamp_virtual = $2
                 """,
-                [name, dt],
+                name,
+                dt,
             )
 
         return exists != []
@@ -360,14 +361,13 @@ class Storage:
             return []
 
         async with self.conn as conn:
-            placeholders = ", ".join(["?" for _ in featured_article_snapshot_ids])
             rows = await conn.execute_fetchall(
                 f"""
                     SELECT *
                     FROM snapshot_apparitions
-                    WHERE featured_article_snapshot_id IN ({placeholders})
+                    WHERE featured_article_snapshot_id IN ({self._placeholders(*featured_article_snapshot_ids)})
                 """,
-                featured_article_snapshot_ids,
+                *featured_article_snapshot_ids,
             )
 
             return [
@@ -405,14 +405,13 @@ class Storage:
 
     async def get_article_embedding(self, featured_article_snapshot_ids: list[int]):
         async with self.conn as conn:
-            placeholders = ", ".join(["?" for _ in featured_article_snapshot_ids])
             rows = await conn.execute_fetchall(
                 f"""
                     SELECT *
                     FROM articles_embeddings
-                    WHERE featured_article_snapshot_id IN ({placeholders})
+                    WHERE featured_article_snapshot_id IN ({self._placeholders(*featured_article_snapshot_ids)})
                 """,
-                featured_article_snapshot_ids,
+                *featured_article_snapshot_ids,
             )
 
             return [self._from_articles_embeddings_row(r) for r in rows]
@@ -432,7 +431,8 @@ class Storage:
                     "articles_embeddings",
                     ["featured_article_snapshot_id", "title_embedding"],
                 ),
-                [featured_article_snapshot_id, embedding],
+                featured_article_snapshot_id,
+                embedding,
             )
             await conn.commit()
 
@@ -452,7 +452,7 @@ class Storage:
                     """
                     SELECT timestamp_virtual
                     FROM snapshot_apparitions sav
-                    WHERE is_main AND site_id = ?
+                    WHERE is_main AND site_id = $1
                     ORDER BY timestamp_virtual DESC
                     LIMIT 1
                     """,
@@ -463,7 +463,7 @@ class Storage:
                     """
                     SELECT timestamp_virtual
                     FROM snapshot_apparitions sav
-                    WHERE is_main AND site_id = ? AND featured_article_snapshot_id = ?
+                    WHERE is_main AND site_id = $1 AND featured_article_snapshot_id = $2
                     """,
                     [site_id, featured_article_snapshot_id],
                 )
@@ -477,7 +477,7 @@ class Storage:
                 WITH original_timestamp AS (
                     {timestamp_query}
                 ), sav_diff AS (
-                    SELECT sav.*, unixepoch(sav.timestamp_virtual) - unixepoch((SELECT * FROM original_timestamp)) AS time_diff
+                    SELECT sav.*, EXTRACT(EPOCH FROM sav.timestamp_virtual - (SELECT * FROM original_timestamp)) :: integer AS time_diff
                     FROM snapshot_apparitions sav
                 )
                 SELECT * FROM (
@@ -487,19 +487,19 @@ class Storage:
                 UNION ALL
                 SELECT * FROM (
                     SELECT * FROM sav_diff
-                    WHERE is_main AND site_id = ? AND time_diff > 0
+                    WHERE is_main AND site_id = $1 AND time_diff > 0
                     ORDER BY time_diff
                     LIMIT 1
                 )
                 UNION ALL
                 SELECT * FROM (
                     SELECT * FROM sav_diff
-                    WHERE is_main AND site_id = ? AND time_diff < 0
+                    WHERE is_main AND site_id = $1 AND time_diff < 0
                     ORDER BY time_diff DESC
                     LIMIT 1
                 )
                 """,
-                timestamp_params + [site_id, site_id],
+                *(timestamp_params),
             )
 
             return [
@@ -535,7 +535,7 @@ class Storage:
             conn,
             self._insert_stmt("sites", ["name", "original_url"]),
             [name, original_url],
-            "SELECT id FROM sites WHERE name = ?",
+            "SELECT id FROM sites WHERE name = $1",
             [name],
         )
 
@@ -555,7 +555,7 @@ class Storage:
                 ],
             ),
             [snapshot.timestamp, site_id, virtual, snapshot.original, snapshot.url],
-            "SELECT id FROM snapshots WHERE timestamp_virtual = ? AND site_id = ?",
+            "SELECT id FROM snapshots WHERE timestamp_virtual = $1 AND site_id = $2",
             [virtual, site_id],
         )
 
@@ -564,7 +564,7 @@ class Storage:
             conn,
             self._insert_stmt("featured_articles", ["url"]),
             [str(article.url)],
-            "SELECT id FROM featured_articles WHERE url = ?",
+            "SELECT id FROM featured_articles WHERE url = $1",
             [str(article.url)],
         )
 
@@ -578,7 +578,7 @@ class Storage:
                 ["title", "url", "featured_article_id"],
             ),
             [article.title, article.url, featured_article_id],
-            "SELECT id FROM featured_article_snapshots WHERE featured_article_id = ? AND url = ?",
+            "SELECT id FROM featured_article_snapshots WHERE featured_article_id = $1 AND url = $2",
             [featured_article_id, article.url],
         )
 
@@ -587,7 +587,8 @@ class Storage:
             self._insert_stmt(
                 "main_articles", ["snapshot_id", "featured_article_snapshot_id"]
             ),
-            [snapshot_id, article_id],
+            snapshot_id,
+            article_id,
         )
 
     async def _add_top_article(
@@ -598,7 +599,9 @@ class Storage:
                 "top_articles",
                 ["snapshot_id", "featured_article_snapshot_id", "rank"],
             ),
-            [snapshot_id, article_id, article.rank],
+            snapshot_id,
+            article_id,
+            article.rank,
         )
 
     async def _insert_or_get(
@@ -618,13 +621,16 @@ class Storage:
     @staticmethod
     def _insert_stmt(table, cols):
         cols_str = ", ".join(cols)
-        placeholders = ", ".join(("?" for c in cols))
         return f"""
             INSERT INTO {table} ({cols_str})
-            VALUES ({placeholders})
-            ON CONFLICT DO NOTHING;
+            VALUES ({Storage._placeholders(*cols)})
+            ON CONFLICT DO NOTHING
         """
 
+    @staticmethod
+    def _placeholders(*args):
+        return ", ".join([f"${idx + 1}" for idx, _ in enumerate(args)])
+
     @property
     def _table_by_name(self):
         return {t.name: t for t in self.tables}