Преглед изворног кода

A better solution for rate limiting

jherve пре годину дана
родитељ
комит
b1a81e01b5

+ 0 - 54
src/de_quoi_parle_le_monde/http.py

@@ -1,54 +0,0 @@
-from attrs import define
-from aiohttp_client_cache import SQLiteBackend
-from aiohttp_client_cache.session import CacheMixin
-from aiohttp.client import ClientSession
-from aiolimiter import AsyncLimiter
-import asyncio
-
-
-class SemaphoreMixin:
-    async def _request(self, *args, **kwargs):
-        await self.sem.acquire()
-        req = await super()._request(*args, **kwargs)
-        self.sem.release()
-        return req
-
-
-class RateLimiterMixin:
-    async def _request(self, *args, **kwargs):
-        async with self.limiter:
-            return await super()._request(*args, **kwargs)
-
-
-@define
-class LimitedCachedSession(CacheMixin, SemaphoreMixin, RateLimiterMixin, ClientSession):
-    sem: asyncio.Semaphore
-    limiter: AsyncLimiter
-    cache: SQLiteBackend
-
-    def __init__(self):
-        self.sem = asyncio.Semaphore(5)
-        self.limiter = AsyncLimiter(2.0, 1.0)
-        super().__init__(cache=SQLiteBackend("http"))
-
-
-class HttpSession:
-    def __init__(self):
-        self.session = LimitedCachedSession()
-
-    async def get(self, url, params=None):
-        async with self.session.get(url, allow_redirects=True, params=params) as resp:
-            resp.raise_for_status()
-            return await resp.text()
-
-    async def __aenter__(self):
-        await self.session.__aenter__()
-        return self
-
-    async def __aexit__(self, exc_type, exc, tb):
-        return await self.session.__aexit__(exc_type, exc, tb)
-
-
-class HttpClient:
-    def session(self):
-        return HttpSession()

+ 33 - 3
src/de_quoi_parle_le_monde/internet_archive.py

@@ -2,8 +2,9 @@ from attrs import frozen, field
 from typing import Optional, ClassVar, NewType
 from datetime import date, datetime, timedelta
 import cattrs
+from aiohttp.client import ClientSession, TCPConnector
+from aiolimiter import AsyncLimiter
 
-from de_quoi_parle_le_monde.http import HttpSession
 
 Timestamp = NewType("Timestamp", datetime)
 datetime_format = "%Y%m%d%H%M%S"
@@ -93,10 +94,22 @@ class InternetArchiveSnapshot:
     text: str = field(repr=False)
 
 
+class RateLimitedConnector(TCPConnector):
+    def __init__(self, *args, **kwargs):
+        limiter_max_rate = kwargs.pop("limiter_max_rate")
+        limiter_time_period = kwargs.pop("limiter_time_period", 60)
+        self._limiter = AsyncLimiter(limiter_max_rate, limiter_time_period)
+        super().__init__(*args, **kwargs)
+
+    async def connect(self, req, *args, **kwargs):
+        async with self._limiter:
+            return await super().connect(req, *args, **kwargs)
+
+
 @frozen
 class InternetArchiveClient:
     # https://github.com/internetarchive/wayback/tree/master/wayback-cdx-server
-    session: HttpSession
+    session: ClientSession
     search_url: ClassVar[str] = "http://web.archive.org/cdx/search/cdx"
 
     async def search_snapshots(
@@ -135,5 +148,22 @@ class InternetArchiveClient:
         else:
             raise SnapshotNotYetAvailable(dt)
 
+    async def __aenter__(self):
+        await self.session.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return await self.session.__aexit__(exc_type, exc, tb)
+
     async def _get(self, url, params=None):
-        return await self.session.get(url, params)
+        async with self.session.get(url, allow_redirects=True, params=params) as resp:
+            resp.raise_for_status()
+            return await resp.text()
+
+    @staticmethod
+    def create(limiter_max_rate, limiter_time_period):
+        conn = RateLimitedConnector(
+            limiter_max_rate=limiter_max_rate, limiter_time_period=limiter_time_period
+        )
+        session = ClientSession(connector=conn)
+        return InternetArchiveClient(session)

+ 1 - 4
src/de_quoi_parle_le_monde/snapshots.py

@@ -9,7 +9,6 @@ from loguru import logger
 
 
 from de_quoi_parle_le_monde.article import ArchiveCollection
-from de_quoi_parle_le_monde.http import HttpClient, HttpSession
 from de_quoi_parle_le_monde.internet_archive import (
     InternetArchiveClient,
     SnapshotNotYetAvailable,
@@ -132,14 +131,12 @@ class SnapshotWorker:
 
 
 async def main():
-    http_client = HttpClient()
     storage = await Storage.create()
 
     logger.info("Starting snapshot service..")
     jobs = SnapshotJob.create(10, [8, 12, 18, 22])
 
-    async with http_client.session() as session:
-        ia = InternetArchiveClient(session)
+    async with InternetArchiveClient.create(1.0, 1.0) as ia:
         worker = SnapshotWorker(storage, ia)
         await asyncio.gather(*[worker.run(job) for job in jobs])
     logger.info("Snapshot service exiting")