فهرست منبع

Add _insert_or_get method

jherve 1 سال پیش
والد
کامیت
60bca41b48
1فایلهای تغییر یافته به همراه46 افزوده شده و 61 حذف شده
  1. 46 61
      src/de_quoi_parle_le_monde/storage.py

+ 46 - 61
src/de_quoi_parle_le_monde/storage.py

@@ -1,3 +1,4 @@
+from typing import Any
 import aiosqlite
 import asyncio
 from datetime import datetime
@@ -193,93 +194,61 @@ class Storage:
             )
 
     async def add_site(self, original_url: str) -> int:
-        async with self.conn as conn:
-            (id_,) = await conn.execute_insert(
-                self._insert_stmt("sites", ["original_url"]),
-                [original_url],
-            )
-
-            if id_ == 0:
-                [(id_,)] = await conn.execute_fetchall(
-                    """
+        return await self._insert_or_get(
+            self._insert_stmt("sites", ["original_url"]),
+            [original_url],
+            """
                     SELECT id
                     FROM sites
                     WHERE original_url = ?
                     """,
-                    [original_url],
-                )
-
-            await conn.commit()
-            return id_
+            [original_url],
+        )
 
     async def add_snapshot(
         self, site_id: int, snapshot: InternetArchiveSnapshotId, virtual: datetime
     ) -> int:
-        async with self.conn as conn:
-            (id_,) = await conn.execute_insert(
-                self._insert_stmt(
-                    "snapshots", ["timestamp", "site_id", "timestamp_virtual"]
-                ),
-                [snapshot.timestamp, site_id, virtual],
-            )
-
-            if id_ == 0:
-                [(id_,)] = await conn.execute_fetchall(
-                    """
+        return await self._insert_or_get(
+            self._insert_stmt(
+                "snapshots", ["timestamp", "site_id", "timestamp_virtual"]
+            ),
+            [snapshot.timestamp, site_id, virtual],
+            """
                     SELECT id
                     FROM snapshots
                     WHERE timestamp_virtual = ? AND site_id = ?
                     """,
-                    [virtual, site_id],
-                )
-
-            await conn.commit()
-            return id_
+            [virtual, site_id],
+        )
 
     async def add_featured_article(self, article: FeaturedArticle):
-        async with self.conn as conn:
-            (id_,) = await conn.execute_insert(
-                self._insert_stmt("featured_articles", ["url"]),
-                [str(article.url)],
-            )
-
-            if id_ == 0:
-                [(id_,)] = await conn.execute_fetchall(
-                    """
+        return await self._insert_or_get(
+            self._insert_stmt("featured_articles", ["url"]),
+            [str(article.url)],
+            """
                     SELECT id
                     FROM featured_articles
                     WHERE url = ?
                     """,
-                    [str(article.url)],
-                )
-
-            await conn.commit()
-            return id_
+            [str(article.url)],
+        )
 
     async def add_featured_article_snapshot(
         self, featured_article_id: int, article: FeaturedArticleSnapshot
     ):
-        async with self.conn as conn:
-            (id_,) = await conn.execute_insert(
-                self._insert_stmt(
-                    "featured_article_snapshots",
-                    ["title", "url", "featured_article_id"],
-                ),
-                [article.title, article.url, featured_article_id],
-            )
-
-            if id_ == 0:
-                [(id_,)] = await conn.execute_fetchall(
-                    """
+        return await self._insert_or_get(
+            self._insert_stmt(
+                "featured_article_snapshots",
+                ["title", "url", "featured_article_id"],
+            ),
+            [article.title, article.url, featured_article_id],
+            """
                     SELECT id
                     FROM featured_article_snapshots
                     WHERE featured_article_id = ? AND url = ?
                     """,
-                    [featured_article_id, article.url],
-                )
-
-            await conn.commit()
-            return id_
+            [featured_article_id, article.url],
+        )
 
     async def add_main_article(self, snapshot_id: int, article_id: int):
         async with self.conn as conn:
@@ -313,6 +282,22 @@ class Storage:
                 """,
             )
 
+    async def _insert_or_get(
+        self,
+        insert_stmt: str,
+        insert_args: list[Any],
+        select_stmt: str,
+        select_args: list[Any],
+    ) -> int:
+        async with self.conn as conn:
+            (id_,) = await conn.execute_insert(insert_stmt, insert_args)
+
+            if id_ == 0:
+                [(id_,)] = await conn.execute_fetchall(select_stmt, select_args)
+
+            await conn.commit()
+            return id_
+
     @staticmethod
     def _insert_stmt(table, cols):
         cols_str = ", ".join(cols)