From: Michael Tremer
Date: Wed, 25 Oct 2023 15:12:43 +0000 (+0000)
Subject: ratelimiter: Migrate from memcache to redis
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0056fdbf452b26cc957bd1064fb7744d20ad9067;p=ipfire.org.git

ratelimiter: Migrate from memcache to redis

Signed-off-by: Michael Tremer
---
diff --git a/src/backend/ratelimit.py b/src/backend/ratelimit.py
index ec99cc51..f4b2b4c1 100644
--- a/src/backend/ratelimit.py
+++ b/src/backend/ratelimit.py
@@ -11,8 +11,6 @@ class RateLimiter(misc.Object):
 
 
 class RateLimiterRequest(misc.Object):
-    prefix = "ratelimit"
-
     def init(self, request, handler, minutes, limit):
         self.request = request
         self.handler = handler
@@ -21,81 +19,80 @@ class RateLimiterRequest(misc.Object):
         self.minutes = minutes
         self.limit = limit
 
+        # What is the current time?
         self.now = datetime.datetime.utcnow()
 
-        # Fetch the current counter value from the cache
-        self.counter = self.get_counter()
-
-        # Increment the rate-limiting counter
-        self.increment_counter()
+        # When to expire?
+        self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
 
-        # Write the header if we are not limited
-        if not self.is_ratelimited():
-            self.write_headers()
+        self.prefix = "-".join((
+            self.__class__.__name__,
+            self.request.host,
+            self.request.path,
+            self.request.method,
+            self.request.remote_ip,
+        ))
 
-    def is_ratelimited(self):
+    async def is_ratelimited(self):
         """
             Returns True if the request is prohibited by the rate limiter
         """
+        counter = await self.get_counter()
+
         # The client is rate-limited when more requests have been
         # received than allowed.
-        return self.counter >= self.limit
+        if counter >= self.limit:
+            return True
+
+        # Increment the counter
+        await self.increment_counter()
+
+        # If not ratelimited, write some headers
+        self.write_headers(counter=counter)
+
+    @property
+    def key(self):
+        return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+
+    @property
+    def keys_to_check(self):
+        for minute in range(self.minutes + 1):
+            when = self.now - datetime.timedelta(minutes=minute)
+
+            yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
 
-    def get_counter(self):
+    async def get_counter(self):
         """
             Returns the number of requests that have been done in recent time.
         """
-        keys = self.get_keys_to_check()
+        async with await self.backend.cache.pipeline() as p:
+            for key in self.keys_to_check:
+                await p.get(key)
 
-        res = self.memcache.get_multi(keys)
-        if res:
-            return sum((int(e) for e in res.values()))
+            # Run the pipeline
+            res = await p.execute()
 
-        return 0
+        # Return the sum
+        return sum((int(e) for e in res if e))
 
-    def write_headers(self):
+    def write_headers(self, counter):
         # Send the limit to the user
         self.handler.set_header("X-Rate-Limit-Limit", self.limit)
 
         # Send the user how many requests are left for this time window
-        self.handler.set_header("X-Rate-Limit-Remaining",
-            self.limit - self.counter)
+        self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
 
-        expires = self.now + datetime.timedelta(seconds=self.expires_after)
-        self.handler.set_header("X-Rate-Limit-Reset", expires.strftime("%s"))
+        # Send when the limit resets
+        self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
 
-    def get_key(self):
-        key_prefix = self.get_key_prefix()
+    async def increment_counter(self):
+        async with await self.backend.cache.pipeline() as p:
+            # Increment the key
+            await p.incr(self.key)
 
-        return "%s-%s" % (key_prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+            # Set expiry
+            await p.expireat(self.key, self.expires_at)
 
-    def get_keys_to_check(self):
-        key_prefix = self.get_key_prefix()
-
-        keys = []
-        for minute in range(self.minutes + 1):
-            when = self.now - datetime.timedelta(minutes=minute)
-
-            key = "%s-%s" % (key_prefix, when.strftime("%Y-%m-%d-%H:%M"))
-            keys.append(key)
-
-        return keys
-
-    def get_key_prefix(self):
-        return "-".join((self.prefix, self.request.host, self.request.path,
-            self.request.method, self.request.remote_ip,))
-
-    def increment_counter(self):
-        key = self.get_key()
-
-        # Add the key or increment if it already exists
-        if not self.memcache.add(key, "1", self.expires_after):
-            self.memcache.incr(key)
-
-    @property
-    def expires_after(self):
-        """
-            Returns the number of seconds after which the counter has reset.
-        """
-        return (self.minutes + 1) * 60
+            # Run the pipeline
+            await p.execute()
diff --git a/src/web/base.py b/src/web/base.py
index ad217517..ce7202c4 100644
--- a/src/web/base.py
+++ b/src/web/base.py
@@ -1,5 +1,6 @@
 #!/usr/bin/python
 
+import asyncio
 import datetime
 import dateutil.parser
 import functools
@@ -14,23 +15,34 @@ from ..decorators import *
 from .. import util
 
 class ratelimit(object):
-    def __init__(self, minutes=15, requests=180):
-        self.minutes = minutes
+    """
+        A decorator class which limits how often a function can be called
+    """
+    def __init__(self, *, minutes, requests):
+        self.minutes = minutes
         self.requests = requests
 
     def __call__(self, method):
         @functools.wraps(method)
-        def wrapper(handler, *args, **kwargs):
+        async def wrapper(handler, *args, **kwargs):
             # Pass the request to the rate limiter and get a request object
             req = handler.backend.ratelimiter.handle_request(handler.request, handler,
                 minutes=self.minutes, limit=self.requests)
 
             # If the rate limit has been reached, we won't allow
             # processing the request and therefore send HTTP error code 429.
-            if req.is_ratelimited():
+            if await req.is_ratelimited():
                 raise tornado.web.HTTPError(429, "Rate limit exceeded")
 
-            return method(handler, *args, **kwargs)
+            # Call the wrapped method
+            result = method(handler, *args, **kwargs)
+
+            # Await it if it is a coroutine
+            if asyncio.iscoroutine(result):
+                return await result
+
+            # Return the result
+            return result
 
         return wrapper