Bladeren bron

Prevent database lock with a semaphore on SQL connection

jherve 1 jaar geleden
bovenliggende
commit
94a61c82db
1 gewijzigde bestanden met toevoegingen van 44 en 10 verwijderingen
  1. 44 10
      src/de_quoi_parle_le_monde/storage.py

+ 44 - 10
src/de_quoi_parle_le_monde/storage.py

@@ -1,13 +1,43 @@
 import aiosqlite
+import asyncio
 from datetime import datetime
 
 from de_quoi_parle_le_monde.article import MainArticle, TopArticle, FeaturedArticle
 from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
 
 
+class DbConnection:
+    def __init__(self, conn_str):
+        self.connection_string = conn_str
+        self.semaphore = asyncio.Semaphore(1)
+        self.conn = None
+
+    async def __aenter__(self):
+        await self.semaphore.acquire()
+        self.conn = await aiosqlite.connect(self.connection_string)
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.conn.close()
+        self.conn = None
+        self.semaphore.release()
+
+    async def execute(self, *args, **kwargs):
+        return await self.conn.execute(*args, **kwargs)
+
+    async def execute_fetchall(self, *args, **kwargs):
+        return await self.conn.execute_fetchall(*args, **kwargs)
+
+    async def execute_insert(self, *args, **kwargs):
+        return await self.conn.execute_insert(*args, **kwargs)
+
+    async def commit(self):
+        return await self.conn.commit()
+
+
 class Storage:
     def __init__(self):
-        self.conn_str = "test.db"
+        self.conn = DbConnection("test.db")
 
     @staticmethod
     async def create():
@@ -16,7 +46,7 @@ class Storage:
         return storage
 
     async def _create_db(self):
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             await conn.execute(
                 """
                 CREATE TABLE IF NOT EXISTS sites (
@@ -143,7 +173,7 @@ class Storage:
             )
 
     async def add_site(self, original_url: str) -> int:
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             (id_,) = await conn.execute_insert(
                 self._insert_stmt("sites", ["original_url"]),
                 [original_url],
@@ -165,7 +195,7 @@ class Storage:
     async def add_snapshot(
         self, site_id: int, snapshot: InternetArchiveSnapshotId, virtual: datetime
     ) -> int:
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             (id_,) = await conn.execute_insert(
                 self._insert_stmt(
                     "snapshots", ["timestamp", "site_id", "timestamp_virtual"]
@@ -187,7 +217,7 @@ class Storage:
             return id_
 
     async def add_featured_article(self, article: FeaturedArticle):
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             (id_,) = await conn.execute_insert(
                 self._insert_stmt("featured_articles", ["title", "url"]),
                 [article.title, article.url],
@@ -207,15 +237,19 @@ class Storage:
             return id_
 
     async def add_main_article(self, snapshot_id: int, article_id: int):
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             await conn.execute_insert(
-                self._insert_stmt("main_articles", ["snapshot_id", "featured_article_id"]),
+                self._insert_stmt(
+                    "main_articles", ["snapshot_id", "featured_article_id"]
+                ),
                 [snapshot_id, article_id],
             )
             await conn.commit()
 
-    async def add_top_article(self, snapshot_id: int, article_id: int, article: TopArticle):
-        async with aiosqlite.connect(self.conn_str) as conn:
+    async def add_top_article(
+        self, snapshot_id: int, article_id: int, article: TopArticle
+    ):
+        async with self.conn as conn:
             await conn.execute_insert(
                 self._insert_stmt(
                     "top_articles", ["snapshot_id", "featured_article_id", "rank"]
@@ -225,7 +259,7 @@ class Storage:
             await conn.commit()
 
     async def select_from(self, table):
-        async with aiosqlite.connect(self.conn_str) as conn:
+        async with self.conn as conn:
             return await conn.execute_fetchall(
                 f"""
                     SELECT *