Просмотр исходного кода

Move DbConnection to an abstract + implementation scheme

jherve 1 год назад
Родитель
Сommit
d999f7b82b

+ 21 - 0
src/de_quoi_parle_le_monde/db/connection.py

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class DbConnection(ABC):
+    @abstractmethod
+    async def __aenter__(self): ...
+
+    @abstractmethod
+    async def __aexit__(self, exc_type, exc, tb): ...
+
+    @abstractmethod
+    async def execute(self, *args, **kwargs): ...
+
+    @abstractmethod
+    async def execute_fetchall(self, *args, **kwargs): ...
+
+    @abstractmethod
+    async def execute_insert(self, *args, **kwargs): ...
+
+    @abstractmethod
+    async def commit(self): ...

+ 33 - 0
src/de_quoi_parle_le_monde/db/sqlite.py

@@ -0,0 +1,33 @@
+import asyncio
+import aiosqlite
+
+from .connection import DbConnection
+
+
+class DbConnectionSQLite(DbConnection):
+    def __init__(self, conn_str):
+        self.connection_string = conn_str
+        self.semaphore = asyncio.Semaphore(1)
+        self.conn = None
+
+    async def __aenter__(self):
+        await self.semaphore.acquire()
+        self.conn = await aiosqlite.connect(self.connection_string)
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.conn.close()
+        self.conn = None
+        self.semaphore.release()
+
+    async def execute(self, *args, **kwargs):
+        return await self.conn.execute(*args, **kwargs)
+
+    async def execute_fetchall(self, *args, **kwargs):
+        return await self.conn.execute_fetchall(*args, **kwargs)
+
+    async def execute_insert(self, *args, **kwargs):
+        return await self.conn.execute_insert(*args, **kwargs)
+
+    async def commit(self):
+        return await self.conn.commit()

+ 2 - 32
src/de_quoi_parle_le_monde/storage.py

@@ -1,6 +1,4 @@
 from typing import Any
-import aiosqlite
-import asyncio
 from datetime import datetime
 import numpy as np
 from attrs import frozen
@@ -12,38 +10,10 @@ from de_quoi_parle_le_monde.article import (
     FeaturedArticleSnapshot,
     FeaturedArticle,
 )
+from de_quoi_parle_le_monde.db.sqlite import DbConnectionSQLite
 from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
 
 
-class DbConnection:
-    def __init__(self, conn_str):
-        self.connection_string = conn_str
-        self.semaphore = asyncio.Semaphore(1)
-        self.conn = None
-
-    async def __aenter__(self):
-        await self.semaphore.acquire()
-        self.conn = await aiosqlite.connect(self.connection_string)
-        return self
-
-    async def __aexit__(self, exc_type, exc, tb):
-        await self.conn.close()
-        self.conn = None
-        self.semaphore.release()
-
-    async def execute(self, *args, **kwargs):
-        return await self.conn.execute(*args, **kwargs)
-
-    async def execute_fetchall(self, *args, **kwargs):
-        return await self.conn.execute_fetchall(*args, **kwargs)
-
-    async def execute_insert(self, *args, **kwargs):
-        return await self.conn.execute_insert(*args, **kwargs)
-
-    async def commit(self):
-        return await self.conn.commit()
-
-
 @frozen
 class UniqueIndex:
     name: str
@@ -320,7 +290,7 @@ class Storage:
             if conn_url.path.startswith("//"):
                 raise ValueError("Absolute URLs not supported for sqlite")
             elif conn_url.path.startswith("/"):
-                self.conn = DbConnection(conn_url.path[1:])
+                self.conn = DbConnectionSQLite(conn_url.path[1:])
         else:
             raise ValueError("Only the SQLite backend is supported")