From: Michael Tremer Date: Fri, 11 Aug 2023 14:45:56 +0000 (+0000) Subject: web: Create a simple rate limiter X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4f1cef10ba75bbd0a6c4860af67d0b36a325478c;p=pbs.git web: Create a simple rate limiter Signed-off-by: Michael Tremer --- diff --git a/Makefile.am b/Makefile.am index 1b0031c7..eb6f0534 100644 --- a/Makefile.am +++ b/Makefile.am @@ -102,6 +102,7 @@ buildservice_PYTHON = \ src/buildservice/mirrors.py \ src/buildservice/misc.py \ src/buildservice/packages.py \ + src/buildservice/ratelimiter.py \ src/buildservice/releasemonitoring.py \ src/buildservice/repository.py \ src/buildservice/sessions.py \ diff --git a/src/buildservice/__init__.py b/src/buildservice/__init__.py index b69493fc..bc5aa31a 100644 --- a/src/buildservice/__init__.py +++ b/src/buildservice/__init__.py @@ -29,6 +29,7 @@ from . import logstreams from . import messages from . import mirrors from . import packages +from . import ratelimiter from . import releasemonitoring from . import repository from . import settings @@ -80,6 +81,7 @@ class Backend(object): self.mirrors = mirrors.Mirrors(self) self.packages = packages.Packages(self) self.monitorings = releasemonitoring.Monitorings(self) + self.ratelimiter = ratelimiter.RateLimiter(self) self.repos = repository.Repositories(self) self.sessions = sessions.Sessions(self) self.sources = sources.Sources(self) diff --git a/src/buildservice/cache.py b/src/buildservice/cache.py index 0b70c4f2..a59b9af8 100644 --- a/src/buildservice/cache.py +++ b/src/buildservice/cache.py @@ -115,3 +115,13 @@ class Cache(object): conn = await self.connection() return await conn.delete(key) + + async def transaction(self, *args, **kwargs): + conn = await self.connection() + + return await conn.transaction(*args, **kwargs) + + async def pipeline(self, *args, **kwargs): + conn = await self.connection() + + return conn.pipeline(*args, **kwargs) diff --git a/src/buildservice/ratelimiter.py b/src/buildservice/ratelimiter.py new file mode 100644 index 00000000..5cfd6fa1 --- /dev/null +++ b/src/buildservice/ratelimiter.py @@ -0,0 +1,117 @@ +#!/usr/bin/python3 +############################################################################### +# # +# Pakfire - The IPFire package management system # +# Copyright (C) 2022 Pakfire development team # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +############################################################################### + +import datetime + +from . import base + +class RateLimiter(base.Object): + def handle_request(self, request, handler, minutes, limit): + return RateLimiterRequest(self.backend, request, handler, + minutes=minutes, limit=limit) + + +class RateLimiterRequest(base.Object): + def init(self, request, handler, minutes, limit): + self.request = request + self.handler = handler + + # Save the limits + self.minutes = minutes + self.limit = limit + + # What is the current time? + self.now = datetime.datetime.utcnow() + + # When to expire? + self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1) + + self.prefix = "-".join(( + self.__class__.__name__, + self.request.host, + self.request.path, + self.request.method, + self.request.remote_ip, + )) + + async def is_ratelimited(self): + """ + Returns True if the request is prohibited by the rate limiter + """ + counter = await self.get_counter() + + # The client is rate-limited when more requests have been + # received than allowed. + if counter >= self.limit: + return True + + # Increment the counter + await self.increment_counter() + + # If not ratelimited, write some headers + self.write_headers(counter=counter) + + @property + def key(self): + return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M")) + + @property + def keys_to_check(self): + for minute in range(self.minutes + 1): + when = self.now - datetime.timedelta(minutes=minute) + + yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M")) + + async def get_counter(self): + """ + Returns the number of requests that have been done in + recent time. + """ + async with await self.backend.cache.pipeline() as p: + for key in self.keys_to_check: + await p.get(key) + + # Run the pipeline + res = await p.execute() + + # Return the sum + return sum((int(e) for e in res if e)) + + def write_headers(self, counter): + # Send the limit to the user + self.handler.set_header("X-Rate-Limit-Limit", self.limit) + + # Send the user how many requests are left for this time window + self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter) + + # Send when the limit resets + self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s")) + + async def increment_counter(self): + async with await self.backend.cache.pipeline() as p: + # Increment the key + await p.incr(self.key) + + # Set expiry + await p.expireat(self.key, self.expires_at) + + # Run the pipeline + await p.execute() diff --git a/src/web/base.py b/src/web/base.py index db3a0f02..b7c96d96 100644 --- a/src/web/base.py +++ b/src/web/base.py @@ -1,6 +1,8 @@ #!/usr/bin/python +import asyncio import base64 +import functools import http.client import json import kerberos @@ -373,3 +375,36 @@ class APIMixin(KerberosAuthMixin, BackendMixin): log.debug("%s" % json.dumps(message, indent=4)) return message + + +class ratelimit(object): + """ + A decorator class which limits how often a function can be called + """ + def __init__(self, *, minutes, requests): + self.minutes = minutes + self.requests = requests + + def __call__(self, method): + @functools.wraps(method) + async def wrapper(handler, *args, **kwargs): + # Pass the request to the rate limiter and get a request object + req = handler.backend.ratelimiter.handle_request(handler.request, + handler, minutes=self.minutes, limit=self.requests) + + # If the rate limit has been reached, we won't allow + # processing the request and therefore send HTTP error code 429. + if await req.is_ratelimited(): + raise tornado.web.HTTPError(429, "Rate limit exceeded") + + # Call the wrapped method + result = method(handler, *args, **kwargs) + + # Await it if it is a coroutine + if asyncio.iscoroutine(result): + return await result + + # Return the result + return result + + return wrapper