class RateLimiterRequest(misc.Object):
- prefix = "ratelimit"
-
def init(self, request, handler, minutes, limit):
self.request = request
self.handler = handler
self.minutes = minutes
self.limit = limit
+ # What is the current time?
self.now = datetime.datetime.utcnow()
- # Fetch the current counter value from the cache
- self.counter = self.get_counter()
-
- # Increment the rate-limiting counter
- self.increment_counter()
+ # When to expire?
+ self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
- # Write the header if we are not limited
- if not self.is_ratelimited():
- self.write_headers()
+ self.prefix = "-".join((
+ self.__class__.__name__,
+ self.request.host,
+ self.request.path,
+ self.request.method,
+ self.request.remote_ip,
+ ))
- def is_ratelimited(self):
+ async def is_ratelimited(self):
"""
Returns True if the request is prohibited by the rate limiter
"""
+ counter = await self.get_counter()
+
# The client is rate-limited when more requests have been
# received than allowed.
- return self.counter >= self.limit
+ if counter >= self.limit:
+ return True
+
+ # Increment the counter
+ await self.increment_counter()
+
+ # If not ratelimited, write some headers
+ self.write_headers(counter=counter)
+
+ @property
+ def key(self):
+ return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+
+ @property
+ def keys_to_check(self):
+ for minute in range(self.minutes + 1):
+ when = self.now - datetime.timedelta(minutes=minute)
+
+ yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
- def get_counter(self):
+ async def get_counter(self):
"""
Returns the number of requests that have been done in
recent time.
"""
- keys = self.get_keys_to_check()
+ async with await self.backend.cache.pipeline() as p:
+ for key in self.keys_to_check:
+ await p.get(key)
- res = self.memcache.get_multi(keys)
- if res:
- return sum((int(e) for e in res.values()))
+ # Run the pipeline
+ res = await p.execute()
- return 0
+ # Return the sum
+ return sum((int(e) for e in res if e))
- def write_headers(self):
+ def write_headers(self, counter):
# Send the limit to the user
self.handler.set_header("X-Rate-Limit-Limit", self.limit)
# Send the user how many requests are left for this time window
- self.handler.set_header("X-Rate-Limit-Remaining",
- self.limit - self.counter)
+ self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
- expires = self.now + datetime.timedelta(seconds=self.expires_after)
- self.handler.set_header("X-Rate-Limit-Reset", expires.strftime("%s"))
+ # Send when the limit resets
+ self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
- def get_key(self):
- key_prefix = self.get_key_prefix()
+ async def increment_counter(self):
+ async with await self.backend.cache.pipeline() as p:
+ # Increment the key
+ await p.incr(self.key)
- return "%s-%s" % (key_prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+ # Set expiry
+ await p.expireat(self.key, self.expires_at)
- def get_keys_to_check(self):
- key_prefix = self.get_key_prefix()
-
- keys = []
- for minute in range(self.minutes + 1):
- when = self.now - datetime.timedelta(minutes=minute)
-
- key = "%s-%s" % (key_prefix, when.strftime("%Y-%m-%d-%H:%M"))
- keys.append(key)
-
- return keys
-
- def get_key_prefix(self):
- return "-".join((self.prefix, self.request.host, self.request.path,
- self.request.method, self.request.remote_ip,))
-
- def increment_counter(self):
- key = self.get_key()
-
- # Add the key or increment if it already exists
- if not self.memcache.add(key, "1", self.expires_after):
- self.memcache.incr(key)
-
- @property
- def expires_after(self):
- """
- Returns the number of seconds after which the counter has reset.
- """
- return (self.minutes + 1) * 60
+ # Run the pipeline
+ await p.execute()
#!/usr/bin/python
+import asyncio
import datetime
import dateutil.parser
import functools
from .. import util
class ratelimit(object):
- def __init__(self, minutes=15, requests=180):
- self.minutes = minutes
+ """
+ A decorator class which limits how often a function can be called
+ """
+ def __init__(self, *, minutes, requests):
+ self.minutes = minutes
self.requests = requests
def __call__(self, method):
@functools.wraps(method)
- def wrapper(handler, *args, **kwargs):
+ async def wrapper(handler, *args, **kwargs):
# Pass the request to the rate limiter and get a request object
req = handler.backend.ratelimiter.handle_request(handler.request,
handler, minutes=self.minutes, limit=self.requests)
# If the rate limit has been reached, we won't allow
# processing the request and therefore send HTTP error code 429.
- if req.is_ratelimited():
+ if await req.is_ratelimited():
raise tornado.web.HTTPError(429, "Rate limit exceeded")
- return method(handler, *args, **kwargs)
+ # Call the wrapped method
+ result = method(handler, *args, **kwargs)
+
+ # Await it if it is a coroutine
+ if asyncio.iscoroutine(result):
+ return await result
+
+ # Return the result
+ return result
return wrapper