]> git.ipfire.org Git - pbs.git/commitdiff
web: Create a simple rate limiter
authorMichael Tremer <michael.tremer@ipfire.org>
Fri, 11 Aug 2023 14:45:56 +0000 (14:45 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Fri, 11 Aug 2023 14:45:56 +0000 (14:45 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/buildservice/__init__.py
src/buildservice/cache.py
src/buildservice/ratelimiter.py [new file with mode: 0644]
src/web/base.py

index 1b0031c75371b7302bd81f2cf4035962c096ab55..eb6f05342ef269c7c97594a7200953b8fa7ab16a 100644 (file)
@@ -102,6 +102,7 @@ buildservice_PYTHON = \
        src/buildservice/mirrors.py \
        src/buildservice/misc.py \
        src/buildservice/packages.py \
+       src/buildservice/ratelimiter.py \
        src/buildservice/releasemonitoring.py \
        src/buildservice/repository.py \
        src/buildservice/sessions.py \
index b69493fcf420b5ad9084b02de716da8eb7906f86..bc5aa31a2fa62d744dfcfae859dbc265065d7203 100644 (file)
@@ -29,6 +29,7 @@ from . import logstreams
 from . import messages
 from . import mirrors
 from . import packages
+from . import ratelimiter
 from . import releasemonitoring
 from . import repository
 from . import settings
@@ -80,6 +81,7 @@ class Backend(object):
                self.mirrors     = mirrors.Mirrors(self)
                self.packages    = packages.Packages(self)
                self.monitorings = releasemonitoring.Monitorings(self)
+               self.ratelimiter = ratelimiter.RateLimiter(self)
                self.repos       = repository.Repositories(self)
                self.sessions    = sessions.Sessions(self)
                self.sources     = sources.Sources(self)
index 0b70c4f296118212d3694433f7b5fefeca596992..a59b9af8d6c366d9c14217553af75e0783c863cc 100644 (file)
@@ -115,3 +115,13 @@ class Cache(object):
                conn = await self.connection()
 
                return await conn.delete(key)
+
+       async def transaction(self, *args, **kwargs):
+               conn = await self.connection()
+
+               return await conn.transaction(*args, **kwargs)
+
+       async def pipeline(self, *args, **kwargs):
+               conn = await self.connection()
+
+               return conn.pipeline(*args, **kwargs)
diff --git a/src/buildservice/ratelimiter.py b/src/buildservice/ratelimiter.py
new file mode 100644 (file)
index 0000000..5cfd6fa
--- /dev/null
@@ -0,0 +1,117 @@
+#!/usr/bin/python3
+###############################################################################
+#                                                                             #
+# Pakfire - The IPFire package management system                              #
+# Copyright (C) 2022 Pakfire development team                                 #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
+
+import datetime
+
+from . import base
+
+class RateLimiter(base.Object):
+       def handle_request(self, request, handler, minutes, limit):
+               return RateLimiterRequest(self.backend, request, handler,
+                       minutes=minutes, limit=limit)
+
+
+class RateLimiterRequest(base.Object):
+       def init(self, request, handler, minutes, limit):
+               self.request = request
+               self.handler = handler
+
+               # Save the limits
+               self.minutes = minutes
+               self.limit   = limit
+
+               # What is the current time?
+               self.now = datetime.datetime.utcnow()
+
+               # When to expire?
+               self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
+
+               self.prefix = "-".join((
+                       self.__class__.__name__,
+                       self.request.host,
+                       self.request.path,
+                       self.request.method,
+                       self.request.remote_ip,
+               ))
+
+       async def is_ratelimited(self):
+               """
+                       Returns True if the request is prohibited by the rate limiter
+               """
+               counter = await self.get_counter()
+
+               # The client is rate-limited when more requests have been
+               # received than allowed.
+               if counter >= self.limit:
+                       return True
+
+               # Increment the counter
+               await self.increment_counter()
+
+               # If not ratelimited, write some headers
+               self.write_headers(counter=counter)
+
+       @property
+       def key(self):
+               return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+
+       @property
+       def keys_to_check(self):
+               for minute in range(self.minutes + 1):
+                       when = self.now - datetime.timedelta(minutes=minute)
+
+                       yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
+
+       async def get_counter(self):
+               """
+                       Returns the number of requests that have been done in
+                       recent time.
+               """
+               async with await self.backend.cache.pipeline() as p:
+                       for key in self.keys_to_check:
+                               await p.get(key)
+
+                       # Run the pipeline
+                       res = await p.execute()
+
+               # Return the sum
+               return sum((int(e) for e in res if e))
+
+       def write_headers(self, counter):
+               # Send the limit to the user
+               self.handler.set_header("X-Rate-Limit-Limit", self.limit)
+
+               # Send the user how many requests are left for this time window
+               self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
+
+               # Send when the limit resets
+               self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
+
+       async def increment_counter(self):
+               async with await self.backend.cache.pipeline() as p:
+                       # Increment the key
+                       await p.incr(self.key)
+
+                       # Set expiry
+                       await p.expireat(self.key, self.expires_at)
+
+                       # Run the pipeline
+                       await p.execute()
index db3a0f02fd42a86e50fd4ae774cca7b680e75b3f..b7c96d96c1fec9025c6904dd145092b58df8a46b 100644 (file)
@@ -1,6 +1,8 @@
 #!/usr/bin/python
 
+import asyncio
 import base64
+import functools
 import http.client
 import json
 import kerberos
@@ -373,3 +375,36 @@ class APIMixin(KerberosAuthMixin, BackendMixin):
                log.debug("%s" % json.dumps(message, indent=4))
 
                return message
+
+
+class ratelimit(object):
+       """
+               A decorator class which limits how often a function can be called
+       """
+       def __init__(self, *, minutes, requests):
+               self.minutes  = minutes
+               self.requests = requests
+
+       def __call__(self, method):
+               @functools.wraps(method)
+               async def wrapper(handler, *args, **kwargs):
+                       # Pass the request to the rate limiter and get a request object
+                       req = handler.backend.ratelimiter.handle_request(handler.request,
+                               handler, minutes=self.minutes, limit=self.requests)
+
+                       # If the rate limit has been reached, we won't allow
+                       # processing the request and therefore send HTTP error code 429.
+                       if await req.is_ratelimited():
+                               raise tornado.web.HTTPError(429, "Rate limit exceeded")
+
+                       # Call the wrapped method
+                       result = method(handler, *args, **kwargs)
+
+                       # Await it if it is a coroutine
+                       if asyncio.iscoroutine(result):
+                               return await result
+
+                       # Return the result
+                       return result
+
+               return wrapper