Selaa lähdekoodia

Proper rewrite of storage-related code for async connections

jherve 1 vuosi sitten
vanhempi
commit
55f6e6cb3b

+ 27 - 28
src/de_quoi_parle_le_monde/db/postgres.py

@@ -1,42 +1,41 @@
 import asyncpg
-import traceback
 
-from .connection import DbConnection
 
-
-class DbConnectionPostgres(DbConnection):
-    def __init__(self, conn_str):
-        self.connection_string = conn_str
-        self.conn = None
+class PostgresConnection:
+    def __init__(self, coro):
+        self.coro = coro
 
     async def __aenter__(self):
-        self.conn = await asyncpg.connect(self.connection_string)
+        self.conn = await self.coro.__aenter__()
         return self
 
     async def __aexit__(self, exc_type, exc, tb):
-        await self.conn.close()
-        self.conn = None
+        await self.coro.__aexit__(exc_type, exc, tb)
 
     async def execute(self, *args, **kwargs):
         return await self.conn.execute(*args, **kwargs)
 
+    async def execute_insert(self, *args, **kwargs):
+        return await self.conn.execute(*args, **kwargs)
+
     async def execute_fetchall(self, *args, **kwargs):
-        try:
-            res = await self.conn.fetch(*args, **kwargs)
-            return res
-        except Exception as e:
-            print("exception on exec of : ", args)
-            traceback.print_exception(e)
-            raise e
+        return await self.conn.fetch(*args, **kwargs)
 
-    async def execute_insert(self, *args, **kwargs):
-        try:
-            ret = await self.conn.execute(*args, **kwargs)
-            return ret
-        except Exception as e:
-            print("exception on exec of : ", args)
-            traceback.print_exception(e)
-            raise e
-
-    async def commit(self):
-        return
+    def transaction(self):
+        return self.conn.transaction()
+
+
+class PostgresBackend:
+    def __init__(self, pool):
+        self.pool = pool
+
+    def get_connection(self):
+        return PostgresConnection(self.pool.acquire())
+
+    @staticmethod
+    async def create(conn_url):
+        pool = await asyncpg.create_pool(conn_url)
+        return PostgresBackend(pool)
+
+    async def close(self):
+        await self.pool.close()

+ 1 - 0
src/de_quoi_parle_le_monde/snapshots.py

@@ -231,6 +231,7 @@ async def main():
             for t in tasks:
                 t.cancel()
 
+    await storage.close()
     logger.info("Snapshot service exiting")
 
 

+ 34 - 32
src/de_quoi_parle_le_monde/storage.py

@@ -11,7 +11,7 @@ from de_quoi_parle_le_monde.article import (
     FeaturedArticle,
 )
 from de_quoi_parle_le_monde.db.sqlite import DbConnectionSQLite
-from de_quoi_parle_le_monde.db.postgres import DbConnectionPostgres
+from de_quoi_parle_le_monde.db.postgres import PostgresBackend
 from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
 
 
@@ -296,7 +296,10 @@ class Storage:
     ]
 
     def __init__(self, backend):
-        self.conn = backend
+        self.backend = backend
+
+    async def close(self):
+        await self.backend.close()
 
     @staticmethod
     async def create():
@@ -311,7 +314,7 @@ class Storage:
             elif conn_url.path.startswith("/"):
                 backend = DbConnectionSQLite(conn_url.path[1:])
         elif conn_url.scheme == "postgresql":
-            backend = DbConnectionPostgres(settings.database_url)
+            backend = await PostgresBackend.create(settings.database_url)
         else:
             raise ValueError("Only the SQLite backend is supported")
 
@@ -320,7 +323,7 @@ class Storage:
         return storage
 
     async def _create_db(self):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             for t in self.tables:
                 await t.create_if_not_exists(conn)
 
@@ -331,7 +334,7 @@ class Storage:
                 await v.create_if_not_exists(conn)
 
     async def exists_snapshot(self, name: str, dt: datetime):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             exists = await conn.execute_fetchall(
                 """
                     SELECT 1
@@ -346,7 +349,7 @@ class Storage:
         return exists != []
 
     async def list_all_featured_article_snapshots(self):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             rows = await conn.execute_fetchall(
                 """
                     SELECT *
@@ -363,7 +366,7 @@ class Storage:
         if len(featured_article_snapshot_ids) == 0:
             return []
 
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             rows = await conn.execute_fetchall(
                 f"""
                     SELECT *
@@ -385,7 +388,7 @@ class Storage:
         return {col: r[idx] for idx, col in enumerate(columns)}
 
     async def list_all_embedded_featured_article_snapshot_ids(self) -> list[int]:
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             rows = await conn.execute_fetchall(
                 """
                     SELECT featured_article_snapshot_id
@@ -396,7 +399,7 @@ class Storage:
             return [r[0] for r in rows]
 
     async def list_all_articles_embeddings(self):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             rows = await conn.execute_fetchall(
                 """
                     SELECT *
@@ -407,7 +410,7 @@ class Storage:
             return [self._from_articles_embeddings_row(r) for r in rows]
 
     async def get_article_embedding(self, featured_article_snapshot_ids: list[int]):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             rows = await conn.execute_fetchall(
                 f"""
                     SELECT *
@@ -428,7 +431,7 @@ class Storage:
         return d
 
     async def add_embedding(self, featured_article_snapshot_id: int, embedding):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             await conn.execute_insert(
                 self._insert_stmt(
                     "articles_embeddings",
@@ -437,10 +440,9 @@ class Storage:
                 featured_article_snapshot_id,
                 embedding,
             )
-            await conn.commit()
 
     async def list_sites(self):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             sites = await conn.execute_fetchall("SELECT * FROM sites")
             return [self._from_row(s, self._table_by_name["sites"]) for s in sites]
 
@@ -449,7 +451,7 @@ class Storage:
         site_id: int,
         featured_article_snapshot_id: int | None = None,
     ):
-        async with self.conn as conn:
+        async with self.backend.get_connection() as conn:
             if featured_article_snapshot_id is None:
                 timestamp_query, timestamp_params = (
                     """
@@ -512,26 +514,26 @@ class Storage:
             ]
 
     async def add_page(self, collection, page, dt):
-        async with self.conn as conn:
-            site_id = await self._add_site(conn, collection.name, collection.url)
-            snapshot_id = await self._add_snapshot(conn, site_id, page.snapshot.id, dt)
-            article_id = await self._add_featured_article(
-                conn, page.main_article.article.original
-            )
-            main_article_snap_id = await self._add_featured_article_snapshot(
-                conn, article_id, page.main_article.article
-            )
-            await self._add_main_article(conn, snapshot_id, main_article_snap_id)
-
-            for t in page.top_articles:
-                article_id = await self._add_featured_article(conn, t.article.original)
-                top_article_snap_id = await self._add_featured_article_snapshot(
-                    conn, article_id, t.article
+        async with self.backend.get_connection() as conn:
+            async with conn.transaction():
+                site_id = await self._add_site(conn, collection.name, collection.url)
+                snapshot_id = await self._add_snapshot(conn, site_id, page.snapshot.id, dt)
+                article_id = await self._add_featured_article(
+                    conn, page.main_article.article.original
+                )
+                main_article_snap_id = await self._add_featured_article_snapshot(
+                    conn, article_id, page.main_article.article
                 )
-                await self._add_top_article(conn, snapshot_id, top_article_snap_id, t)
-            await conn.commit()
+                await self._add_main_article(conn, snapshot_id, main_article_snap_id)
+
+                for t in page.top_articles:
+                    article_id = await self._add_featured_article(conn, t.article.original)
+                    top_article_snap_id = await self._add_featured_article_snapshot(
+                        conn, article_id, t.article
+                    )
+                    await self._add_top_article(conn, snapshot_id, top_article_snap_id, t)
 
-            return site_id
+        return site_id
 
     async def _add_site(self, conn, name: str, original_url: str) -> int:
         return await self._insert_or_get(