Jelajahi Sumber

A better solution for rate limiting

jherve 1 tahun lalu
induk
melakukan
b1a81e01b5

+ 0 - 54
src/de_quoi_parle_le_monde/http.py

@@ -1,54 +0,0 @@
-from attrs import define
-from aiohttp_client_cache import SQLiteBackend
-from aiohttp_client_cache.session import CacheMixin
-from aiohttp.client import ClientSession
-from aiolimiter import AsyncLimiter
-import asyncio
-
-
-class SemaphoreMixin:
-    async def _request(self, *args, **kwargs):
-        await self.sem.acquire()
-        req = await super()._request(*args, **kwargs)
-        self.sem.release()
-        return req
-
-
-class RateLimiterMixin:
-    async def _request(self, *args, **kwargs):
-        async with self.limiter:
-            return await super()._request(*args, **kwargs)
-
-
-@define
-class LimitedCachedSession(CacheMixin, SemaphoreMixin, RateLimiterMixin, ClientSession):
-    sem: asyncio.Semaphore
-    limiter: AsyncLimiter
-    cache: SQLiteBackend
-
-    def __init__(self):
-        self.sem = asyncio.Semaphore(5)
-        self.limiter = AsyncLimiter(2.0, 1.0)
-        super().__init__(cache=SQLiteBackend("http"))
-
-
-class HttpSession:
-    def __init__(self):
-        self.session = LimitedCachedSession()
-
-    async def get(self, url, params=None):
-        async with self.session.get(url, allow_redirects=True, params=params) as resp:
-            resp.raise_for_status()
-            return await resp.text()
-
-    async def __aenter__(self):
-        await self.session.__aenter__()
-        return self
-
-    async def __aexit__(self, exc_type, exc, tb):
-        return await self.session.__aexit__(exc_type, exc, tb)
-
-
-class HttpClient:
-    def session(self):
-        return HttpSession()

+ 33 - 3
src/de_quoi_parle_le_monde/internet_archive.py

@@ -2,8 +2,9 @@ from attrs import frozen, field
 from typing import Optional, ClassVar, NewType
 from datetime import date, datetime, timedelta
 import cattrs
+from aiohttp.client import ClientSession, TCPConnector
+from aiolimiter import AsyncLimiter
 
-from de_quoi_parle_le_monde.http import HttpSession
 
 Timestamp = NewType("Timestamp", datetime)
 datetime_format = "%Y%m%d%H%M%S"
@@ -93,10 +94,22 @@ class InternetArchiveSnapshot:
     text: str = field(repr=False)
 
 
+class RateLimitedConnector(TCPConnector):  # TCPConnector whose connect() is gated by an AsyncLimiter
+    def __init__(self, *args, **kwargs):  # accepts TCPConnector arguments plus the two limiter_* keywords
+        limiter_max_rate = kwargs.pop("limiter_max_rate")  # required keyword: pop() raises KeyError if absent
+        limiter_time_period = kwargs.pop("limiter_time_period", 60)  # optional keyword, defaults to 60
+        self._limiter = AsyncLimiter(limiter_max_rate, limiter_time_period)
+        super().__init__(*args, **kwargs)  # limiter keywords were popped above, so only the rest is forwarded
+
+    async def connect(self, req, *args, **kwargs):  # same call shape as the parent connect(); arguments pass through
+        async with self._limiter:  # enter the limiter before delegating to the parent implementation
+            return await super().connect(req, *args, **kwargs)
+
+
 @frozen
 class InternetArchiveClient:
     # https://github.com/internetarchive/wayback/tree/master/wayback-cdx-server
-    session: HttpSession
+    session: ClientSession
     search_url: ClassVar[str] = "http://web.archive.org/cdx/search/cdx"
 
     async def search_snapshots(
@@ -135,5 +148,22 @@ class InternetArchiveClient:
         else:
             raise SnapshotNotYetAvailable(dt)
 
+    async def __aenter__(self):  # lets the client itself be used with `async with`
+        await self.session.__aenter__()  # enter the wrapped session's async context first
+        return self  # the `as` target is this client, not the underlying session
+
+    async def __aexit__(self, exc_type, exc, tb):  # exit half of the client's async context manager
+        return await self.session.__aexit__(exc_type, exc, tb)  # delegate to the session and return its result
+
     async def _get(self, url, params=None):
-        return await self.session.get(url, params)
+        async with self.session.get(url, allow_redirects=True, params=params) as resp:  # GET issued directly on the session
+            resp.raise_for_status()  # checked before the body is read; text() below runs only if this does not raise
+            return await resp.text()  # the response body as text is the return value of _get
+
+    @staticmethod
+    def create(limiter_max_rate, limiter_time_period):  # factory: returns a client backed by a rate-limited session
+        conn = RateLimitedConnector(  # both arguments are forwarded unchanged as the connector's limiter_* keywords
+            limiter_max_rate=limiter_max_rate, limiter_time_period=limiter_time_period
+        )
+        session = ClientSession(connector=conn)  # the session is built on the rate-limited connector
+        return InternetArchiveClient(session)  # session is the client's single positional field

+ 1 - 4
src/de_quoi_parle_le_monde/snapshots.py

@@ -9,7 +9,6 @@ from loguru import logger
 
 
 from de_quoi_parle_le_monde.article import ArchiveCollection
-from de_quoi_parle_le_monde.http import HttpClient, HttpSession
 from de_quoi_parle_le_monde.internet_archive import (
     InternetArchiveClient,
     SnapshotNotYetAvailable,
@@ -132,14 +131,12 @@ class SnapshotWorker:
 
 
 async def main():
-    http_client = HttpClient()
     storage = await Storage.create()
 
     logger.info("Starting snapshot service..")
     jobs = SnapshotJob.create(10, [8, 12, 18, 22])
 
-    async with http_client.session() as session:
-        ia = InternetArchiveClient(session)
+    async with InternetArchiveClient.create(1.0, 1.0) as ia:
         worker = SnapshotWorker(storage, ia)
         await asyncio.gather(*[worker.run(job) for job in jobs])
     logger.info("Snapshot service exiting")