Parcourir la source

Rework SQLite backend to match the new spec

jherve il y a 1 an
Parent
commit
b7278f41b9

+ 27 - 6
src/de_quoi_parle_le_monde/db/sqlite.py

@@ -1,14 +1,11 @@
 import asyncio
 import aiosqlite
 
-from .connection import DbConnection
 
-
-class DbConnectionSQLite(DbConnection):
+class SqliteConnection:
     def __init__(self, conn_str):
         self.connection_string = conn_str
         self.semaphore = asyncio.Semaphore(1)
-        self.conn = None
 
     async def __aenter__(self):
         await self.semaphore.acquire()
@@ -16,6 +13,8 @@ class DbConnectionSQLite(DbConnection):
         return self
 
     async def __aexit__(self, exc_type, exc, tb):
+        # Reproduce asyncpg' behaviour where commit is implicit
+        await self.conn.commit()
         await self.conn.close()
         self.conn = None
         self.semaphore.release()
@@ -29,5 +28,27 @@ class DbConnectionSQLite(DbConnection):
     async def execute_insert(self, *args, **kwargs):
         return await self.conn.execute_insert(*args, **kwargs)
 
-    async def commit(self):
-        return await self.conn.commit()
+    def transaction(self):
+        class DummyTransaction:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return
+
+        return DummyTransaction()
+
+
+class SqliteBackend:
+    def __init__(self, conn_path):
+        self.conn_path = conn_path
+
+    def get_connection(self):
+        return SqliteConnection(self.conn_path)
+
+    @staticmethod
+    async def create(conn_path):
+        return SqliteBackend(conn_path)
+
+    async def close(self):
+        ...

+ 2 - 2
src/de_quoi_parle_le_monde/storage.py

@@ -10,7 +10,7 @@ from de_quoi_parle_le_monde.article import (
     FeaturedArticleSnapshot,
     FeaturedArticle,
 )
-from de_quoi_parle_le_monde.db.sqlite import DbConnectionSQLite
+from de_quoi_parle_le_monde.db.sqlite import SqliteBackend
 from de_quoi_parle_le_monde.db.postgres import PostgresBackend
 from de_quoi_parle_le_monde.internet_archive import InternetArchiveSnapshotId
 
@@ -312,7 +312,7 @@ class Storage:
             if conn_url.path.startswith("//"):
                 raise ValueError("Absolute URLs not supported for sqlite")
             elif conn_url.path.startswith("/"):
-                backend = DbConnectionSQLite(conn_url.path[1:])
+                backend = await SqliteBackend.create(conn_url.path[1:])
         elif conn_url.scheme == "postgresql":
             backend = await PostgresBackend.create(settings.database_url)
         else: