|
|
@@ -11,7 +11,7 @@ from de_quoi_parle_le_monde.article import (
|
|
|
FeaturedArticle,
|
|
|
)
|
|
|
from de_quoi_parle_le_monde.db.sqlite import DbConnectionSQLite
|
|
|
-from de_quoi_parle_le_monde.db.postgres import DbConnectionPostgres
|
|
|
+from de_quoi_parle_le_monde.db.postgres import PostgresBackend
|
|
|
from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
|
|
|
|
|
|
|
|
|
@@ -296,7 +296,10 @@ class Storage:
|
|
|
]
|
|
|
|
|
|
def __init__(self, backend):
|
|
|
- self.conn = backend
|
|
|
+ self.backend = backend
|
|
|
+
|
|
|
+ async def close(self):
|
|
|
+ await self.backend.close()
|
|
|
|
|
|
@staticmethod
|
|
|
async def create():
|
|
|
@@ -311,7 +314,7 @@ class Storage:
|
|
|
elif conn_url.path.startswith("/"):
|
|
|
backend = DbConnectionSQLite(conn_url.path[1:])
|
|
|
elif conn_url.scheme == "postgresql":
|
|
|
- backend = DbConnectionPostgres(settings.database_url)
|
|
|
+ backend = await PostgresBackend.create(settings.database_url)
|
|
|
else:
|
|
|
raise ValueError("Only the SQLite backend is supported")
|
|
|
|
|
|
@@ -320,7 +323,7 @@ class Storage:
|
|
|
return storage
|
|
|
|
|
|
async def _create_db(self):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
for t in self.tables:
|
|
|
await t.create_if_not_exists(conn)
|
|
|
|
|
|
@@ -331,7 +334,7 @@ class Storage:
|
|
|
await v.create_if_not_exists(conn)
|
|
|
|
|
|
async def exists_snapshot(self, name: str, dt: datetime):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
exists = await conn.execute_fetchall(
|
|
|
"""
|
|
|
SELECT 1
|
|
|
@@ -346,7 +349,7 @@ class Storage:
|
|
|
return exists != []
|
|
|
|
|
|
async def list_all_featured_article_snapshots(self):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
rows = await conn.execute_fetchall(
|
|
|
"""
|
|
|
SELECT *
|
|
|
@@ -363,7 +366,7 @@ class Storage:
|
|
|
if len(featured_article_snapshot_ids) == 0:
|
|
|
return []
|
|
|
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
rows = await conn.execute_fetchall(
|
|
|
f"""
|
|
|
SELECT *
|
|
|
@@ -385,7 +388,7 @@ class Storage:
|
|
|
return {col: r[idx] for idx, col in enumerate(columns)}
|
|
|
|
|
|
async def list_all_embedded_featured_article_snapshot_ids(self) -> list[int]:
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
rows = await conn.execute_fetchall(
|
|
|
"""
|
|
|
SELECT featured_article_snapshot_id
|
|
|
@@ -396,7 +399,7 @@ class Storage:
|
|
|
return [r[0] for r in rows]
|
|
|
|
|
|
async def list_all_articles_embeddings(self):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
rows = await conn.execute_fetchall(
|
|
|
"""
|
|
|
SELECT *
|
|
|
@@ -407,7 +410,7 @@ class Storage:
|
|
|
return [self._from_articles_embeddings_row(r) for r in rows]
|
|
|
|
|
|
async def get_article_embedding(self, featured_article_snapshot_ids: list[int]):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
rows = await conn.execute_fetchall(
|
|
|
f"""
|
|
|
SELECT *
|
|
|
@@ -428,7 +431,7 @@ class Storage:
|
|
|
return d
|
|
|
|
|
|
async def add_embedding(self, featured_article_snapshot_id: int, embedding):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
await conn.execute_insert(
|
|
|
self._insert_stmt(
|
|
|
"articles_embeddings",
|
|
|
@@ -437,10 +440,9 @@ class Storage:
|
|
|
featured_article_snapshot_id,
|
|
|
embedding,
|
|
|
)
|
|
|
- await conn.commit()
|
|
|
|
|
|
async def list_sites(self):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
sites = await conn.execute_fetchall("SELECT * FROM sites")
|
|
|
return [self._from_row(s, self._table_by_name["sites"]) for s in sites]
|
|
|
|
|
|
@@ -449,7 +451,7 @@ class Storage:
|
|
|
site_id: int,
|
|
|
featured_article_snapshot_id: int | None = None,
|
|
|
):
|
|
|
- async with self.conn as conn:
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
if featured_article_snapshot_id is None:
|
|
|
timestamp_query, timestamp_params = (
|
|
|
"""
|
|
|
@@ -512,26 +514,26 @@ class Storage:
|
|
|
]
|
|
|
|
|
|
async def add_page(self, collection, page, dt):
|
|
|
- async with self.conn as conn:
|
|
|
- site_id = await self._add_site(conn, collection.name, collection.url)
|
|
|
- snapshot_id = await self._add_snapshot(conn, site_id, page.snapshot.id, dt)
|
|
|
- article_id = await self._add_featured_article(
|
|
|
- conn, page.main_article.article.original
|
|
|
- )
|
|
|
- main_article_snap_id = await self._add_featured_article_snapshot(
|
|
|
- conn, article_id, page.main_article.article
|
|
|
- )
|
|
|
- await self._add_main_article(conn, snapshot_id, main_article_snap_id)
|
|
|
-
|
|
|
- for t in page.top_articles:
|
|
|
- article_id = await self._add_featured_article(conn, t.article.original)
|
|
|
- top_article_snap_id = await self._add_featured_article_snapshot(
|
|
|
- conn, article_id, t.article
|
|
|
+ async with self.backend.get_connection() as conn:
|
|
|
+ async with conn.transaction():
|
|
|
+ site_id = await self._add_site(conn, collection.name, collection.url)
|
|
|
+ snapshot_id = await self._add_snapshot(conn, site_id, page.snapshot.id, dt)
|
|
|
+ article_id = await self._add_featured_article(
|
|
|
+ conn, page.main_article.article.original
|
|
|
+ )
|
|
|
+ main_article_snap_id = await self._add_featured_article_snapshot(
|
|
|
+ conn, article_id, page.main_article.article
|
|
|
)
|
|
|
- await self._add_top_article(conn, snapshot_id, top_article_snap_id, t)
|
|
|
- await conn.commit()
|
|
|
+ await self._add_main_article(conn, snapshot_id, main_article_snap_id)
|
|
|
+
|
|
|
+ for t in page.top_articles:
|
|
|
+ article_id = await self._add_featured_article(conn, t.article.original)
|
|
|
+ top_article_snap_id = await self._add_featured_article_snapshot(
|
|
|
+ conn, article_id, t.article
|
|
|
+ )
|
|
|
+ await self._add_top_article(conn, snapshot_id, top_article_snap_id, t)
|
|
|
|
|
|
- return site_id
|
|
|
+ return site_id
|
|
|
|
|
|
async def _add_site(self, conn, name: str, original_url: str) -> int:
|
|
|
return await self._insert_or_get(
|