]> git.ipfire.org Git - ipfire.org.git/commitdiff
ratelimiter: Migrate from memcache to redis
authorMichael Tremer <michael.tremer@ipfire.org>
Wed, 25 Oct 2023 15:12:43 +0000 (15:12 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Wed, 25 Oct 2023 15:12:43 +0000 (15:12 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/backend/ratelimit.py
src/web/base.py

index ec99cc511ac7d4068efa42c406002a4050c0c57b..f4b2b4c129ec80aed31a1982097a5b65dd59ad20 100644 (file)
@@ -11,8 +11,6 @@ class RateLimiter(misc.Object):
 
 
 class RateLimiterRequest(misc.Object):
-       prefix = "ratelimit"
-
        def init(self, request, handler, minutes, limit):
                self.request = request
                self.handler = handler
@@ -21,81 +19,80 @@ class RateLimiterRequest(misc.Object):
                self.minutes = minutes
                self.limit   = limit
 
+               # What is the current time?
                self.now = datetime.datetime.utcnow()
 
-               # Fetch the current counter value from the cache
-               self.counter = self.get_counter()
-
-               # Increment the rate-limiting counter
-               self.increment_counter()
+               # When to expire?
+               self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
 
-               # Write the header if we are not limited
-               if not self.is_ratelimited():
-                       self.write_headers()
+               self.prefix = "-".join((
+                       self.__class__.__name__,
+                       self.request.host,
+                       self.request.path,
+                       self.request.method,
+                       self.request.remote_ip,
+               ))
 
-       def is_ratelimited(self):
+       async def is_ratelimited(self):
                """
                        Returns True if the request is prohibited by the rate limiter
                """
+               counter = await self.get_counter()
+
                # The client is rate-limited when more requests have been
                # received than allowed.
-               return self.counter >= self.limit
+               if counter >= self.limit:
+                       return True
+
+               # Increment the counter
+               await self.increment_counter()
+
+               # If not ratelimited, write some headers
+               self.write_headers(counter=counter)
+
+       @property
+       def key(self):
+               return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+
+       @property
+       def keys_to_check(self):
+               for minute in range(self.minutes + 1):
+                       when = self.now - datetime.timedelta(minutes=minute)
+
+                       yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
 
-       def get_counter(self):
+       async def get_counter(self):
                """
                        Returns the number of requests that have been done in
                        recent time.
                """
-               keys = self.get_keys_to_check()
+               async with await self.backend.cache.pipeline() as p:
+                       for key in self.keys_to_check:
+                               await p.get(key)
 
-               res = self.memcache.get_multi(keys)
-               if res:
-                       return sum((int(e) for e in res.values()))
+                       # Run the pipeline
+                       res = await p.execute()
 
-               return 0
+               # Return the sum
+               return sum((int(e) for e in res if e))
 
-       def write_headers(self):
+       def write_headers(self, counter):
                # Send the limit to the user
                self.handler.set_header("X-Rate-Limit-Limit", self.limit)
 
                # Send the user how many requests are left for this time window
-               self.handler.set_header("X-Rate-Limit-Remaining",
-                       self.limit - self.counter)
+               self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
 
-               expires = self.now + datetime.timedelta(seconds=self.expires_after)
-               self.handler.set_header("X-Rate-Limit-Reset", expires.strftime("%s"))
+               # Send when the limit resets
+               self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
 
-       def get_key(self):
-               key_prefix = self.get_key_prefix()
+       async def increment_counter(self):
+               async with await self.backend.cache.pipeline() as p:
+                       # Increment the key
+                       await p.incr(self.key)
 
-               return "%s-%s" % (key_prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+                       # Set expiry
+                       await p.expireat(self.key, self.expires_at)
 
-       def get_keys_to_check(self):
-               key_prefix = self.get_key_prefix()
-
-               keys = []
-               for minute in range(self.minutes + 1):
-                       when = self.now - datetime.timedelta(minutes=minute)
-
-                       key = "%s-%s" % (key_prefix, when.strftime("%Y-%m-%d-%H:%M"))
-                       keys.append(key)
-
-               return keys
-
-       def get_key_prefix(self):
-               return "-".join((self.prefix, self.request.host, self.request.path,
-                       self.request.method, self.request.remote_ip,))
-
-       def increment_counter(self):
-               key = self.get_key()
-
-               # Add the key or increment if it already exists
-               if not self.memcache.add(key, "1", self.expires_after):
-                       self.memcache.incr(key)
-
-       @property
-       def expires_after(self):
-               """
-                       Returns the number of seconds after which the counter has reset.
-               """
-               return (self.minutes + 1) * 60
+                       # Run the pipeline
+                       await p.execute()
index ad2175173a6287bf18b42a2f7e9db01dda5eef41..ce7202c4e95601d77658b33ce0cde908afe5b6b6 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/python
 
+import asyncio
 import datetime
 import dateutil.parser
 import functools
@@ -14,23 +15,34 @@ from ..decorators import *
 from .. import util
 
 class ratelimit(object):
-       def __init__(self, minutes=15, requests=180):
-               self.minutes = minutes
+       """
+               A decorator class which limits how often a function can be called
+       """
+       def __init__(self, *, minutes, requests):
+               self.minutes  = minutes
                self.requests = requests
 
        def __call__(self, method):
                @functools.wraps(method)
-               def wrapper(handler, *args, **kwargs):
+               async def wrapper(handler, *args, **kwargs):
                        # Pass the request to the rate limiter and get a request object
                        req = handler.backend.ratelimiter.handle_request(handler.request,
                                handler, minutes=self.minutes, limit=self.requests)
 
                        # If the rate limit has been reached, we won't allow
                        # processing the request and therefore send HTTP error code 429.
-                       if req.is_ratelimited():
+                       if await req.is_ratelimited():
                                raise tornado.web.HTTPError(429, "Rate limit exceeded")
 
-                       return method(handler, *args, **kwargs)
+                       # Call the wrapped method
+                       result = method(handler, *args, **kwargs)
+
+                       # Await it if it is a coroutine
+                       if asyncio.iscoroutine(result):
+                               return await result
+
+                       # Return the result
+                       return result
 
                return wrapper