From: Michael Tremer Date: Tue, 21 Jan 2025 14:36:20 +0000 (+0000) Subject: ratelimiter: Allow passing arbitrary arguments X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=6b497fdce838af858d39c4864530b062a75bb9a7;p=pbs.git ratelimiter: Allow passing arbitrary arguments Signed-off-by: Michael Tremer --- diff --git a/src/buildservice/ratelimiter.py b/src/buildservice/ratelimiter.py index 965bb840..c647744e 100644 --- a/src/buildservice/ratelimiter.py +++ b/src/buildservice/ratelimiter.py @@ -54,9 +54,11 @@ ratelimiter = sqlalchemy.Table( class RateLimiter(base.Object): - def handle_request(self, request, handler, minutes, limit, **kwargs): - return RateLimiterRequest(self.backend, request, handler, - minutes=minutes, limit=limit, **kwargs) + def __call__(self, *args, **kwargs): + """ + Launch a new request + """ + return RateLimiterRequest(self.backend, *args, **kwargs) async def cleanup(backend): """ @@ -76,7 +78,7 @@ class RateLimiter(base.Object): class RateLimiterRequest(base.Object): - def init(self, request, handler, minutes, limit, key=None): + def init(self, request, handler, *, minutes, limit, key=None): self.request = request self.handler = handler diff --git a/src/web/auth.py b/src/web/auth.py index 60653895..19dd17cb 100644 --- a/src/web/auth.py +++ b/src/web/auth.py @@ -16,7 +16,7 @@ class LoginHandler(base.KerberosAuthMixin, base.BaseHandler): await self.render("login.html", username=username, failed=failed) - @base.ratelimit(requests=10, minutes=5) + @base.ratelimit(limit=10, minutes=5) async def post(self): # Fetch credentials username = self.get_argument("username") diff --git a/src/web/base.py b/src/web/base.py index 357922d7..5f914c0f 100644 --- a/src/web/base.py +++ b/src/web/base.py @@ -820,16 +820,14 @@ class ratelimit(object): """ A decorator class which limits how often a function can be called """ - def __init__(self, *, minutes, requests): - self.minutes = minutes - self.requests = requests + def __init__(self, **kwargs): + self.kwargs = kwargs def __call__(self, method): @functools.wraps(method) async def wrapper(handler, *args, **kwargs): # Pass the request to the rate limiter and get a request object - req = handler.backend.ratelimiter.handle_request(handler.request, - handler, minutes=self.minutes, limit=self.requests) + req = handler.backend.ratelimiter(handler.request, handler, **self.kwargs) # If the rate limit has been reached, we won't allow # processing the request and therefore send HTTP error code 429.