|
|
@@ -1,13 +1,43 @@
|
|
|
import aiosqlite
|
|
|
+import asyncio
|
|
|
from datetime import datetime
|
|
|
|
|
|
from de_quoi_parle_le_monde.article import MainArticle, TopArticle, FeaturedArticle
|
|
|
from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
|
|
|
|
|
|
|
|
|
+class DbConnection:
|
|
|
+ def __init__(self, conn_str):
|
|
|
+ self.connection_string = conn_str
|
|
|
+ self.semaphore = asyncio.Semaphore(1)
|
|
|
+ self.conn = None
|
|
|
+
|
|
|
+ async def __aenter__(self):
|
|
|
+ await self.semaphore.acquire()
|
|
|
+ self.conn = await aiosqlite.connect(self.connection_string)
|
|
|
+ return self
|
|
|
+
|
|
|
+ async def __aexit__(self, exc_type, exc, tb):
|
|
|
+ await self.conn.close()
|
|
|
+ self.conn = None
|
|
|
+ self.semaphore.release()
|
|
|
+
|
|
|
+ async def execute(self, *args, **kwargs):
|
|
|
+ return await self.conn.execute(*args, **kwargs)
|
|
|
+
|
|
|
+ async def execute_fetchall(self, *args, **kwargs):
|
|
|
+ return await self.conn.execute_fetchall(*args, **kwargs)
|
|
|
+
|
|
|
+ async def execute_insert(self, *args, **kwargs):
|
|
|
+ return await self.conn.execute_insert(*args, **kwargs)
|
|
|
+
|
|
|
+ async def commit(self):
|
|
|
+ return await self.conn.commit()
|
|
|
+
|
|
|
+
|
|
|
class Storage:
|
|
|
def __init__(self):
|
|
|
- self.conn_str = "test.db"
|
|
|
+ self.conn = DbConnection("test.db")
|
|
|
|
|
|
@staticmethod
|
|
|
async def create():
|
|
|
@@ -16,7 +46,7 @@ class Storage:
|
|
|
return storage
|
|
|
|
|
|
async def _create_db(self):
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
await conn.execute(
|
|
|
"""
|
|
|
CREATE TABLE IF NOT EXISTS sites (
|
|
|
@@ -143,7 +173,7 @@ class Storage:
|
|
|
)
|
|
|
|
|
|
async def add_site(self, original_url: str) -> int:
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
(id_,) = await conn.execute_insert(
|
|
|
self._insert_stmt("sites", ["original_url"]),
|
|
|
[original_url],
|
|
|
@@ -165,7 +195,7 @@ class Storage:
|
|
|
async def add_snapshot(
|
|
|
self, site_id: int, snapshot: InternetArchiveSnapshotId, virtual: datetime
|
|
|
) -> int:
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
(id_,) = await conn.execute_insert(
|
|
|
self._insert_stmt(
|
|
|
"snapshots", ["timestamp", "site_id", "timestamp_virtual"]
|
|
|
@@ -187,7 +217,7 @@ class Storage:
|
|
|
return id_
|
|
|
|
|
|
async def add_featured_article(self, article: FeaturedArticle):
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
(id_,) = await conn.execute_insert(
|
|
|
self._insert_stmt("featured_articles", ["title", "url"]),
|
|
|
[article.title, article.url],
|
|
|
@@ -207,15 +237,19 @@ class Storage:
|
|
|
return id_
|
|
|
|
|
|
async def add_main_article(self, snapshot_id: int, article_id: int):
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
await conn.execute_insert(
|
|
|
- self._insert_stmt("main_articles", ["snapshot_id", "featured_article_id"]),
|
|
|
+ self._insert_stmt(
|
|
|
+ "main_articles", ["snapshot_id", "featured_article_id"]
|
|
|
+ ),
|
|
|
[snapshot_id, article_id],
|
|
|
)
|
|
|
await conn.commit()
|
|
|
|
|
|
- async def add_top_article(self, snapshot_id: int, article_id: int, article: TopArticle):
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async def add_top_article(
|
|
|
+ self, snapshot_id: int, article_id: int, article: TopArticle
|
|
|
+ ):
|
|
|
+ async with self.conn as conn:
|
|
|
await conn.execute_insert(
|
|
|
self._insert_stmt(
|
|
|
"top_articles", ["snapshot_id", "featured_article_id", "rank"]
|
|
|
@@ -225,7 +259,7 @@ class Storage:
|
|
|
await conn.commit()
|
|
|
|
|
|
async def select_from(self, table):
|
|
|
- async with aiosqlite.connect(self.conn_str) as conn:
|
|
|
+ async with self.conn as conn:
|
|
|
return await conn.execute_fetchall(
|
|
|
f"""
|
|
|
SELECT *
|