###############################################################################
import datetime
+import ipaddress
+
+import sqlalchemy
+from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint
+from sqlalchemy.dialects.postgresql import INET
from . import base
+from . import database
+
+ratelimiter = sqlalchemy.Table(
+ "ratelimiter", database.Base.metadata,
+
+ # Key
+ Column("key", Text, nullable=False),
+
+ # Timestamp
+ Column("timestamp", DateTime(timezone=False), nullable=False,
+ server_default=sqlalchemy.func.current_timestamp()),
+
+ # Address
+ Column("address", INET, nullable=False),
+
+ # Requests
+ Column("requests", Integer, nullable=False, default=1),
+
+ # Expires At
+ Column("expires_at", DateTime(timezone=False), nullable=False),
+
+ # Unique constraint
+ UniqueConstraint("key", "timestamp", "address", name="ratelimiter_unique")
+)
+
class RateLimiter(base.Object):
- def handle_request(self, request, handler, minutes, limit):
+ def handle_request(self, request, handler, minutes, limit, **kwargs):
return RateLimiterRequest(self.backend, request, handler,
- minutes=minutes, limit=limit)
+ minutes=minutes, limit=limit, **kwargs)
+
+ async def cleanup(backend):
+ """
+ Called to cleanup the ratelimiter from expired entries
+ """
+ # Delete everything that has expired in the past
+ stmt = (
+ ratelimiter
+ .delete()
+ .where(
+ ratelimiter.c.expired_at <= sqlalchemy.func.current_timestamp(),
+ )
+ )
+
+ # Run the query
+ await self.db.execute(stmt)
class RateLimiterRequest(base.Object):
- def init(self, request, handler, minutes, limit):
+ def init(self, request, handler, minutes, limit, key=None):
self.request = request
self.handler = handler
self.minutes = minutes
self.limit = limit
+ # Create a default key if none given
+ if key is None:
+ key = "%s-%s-%s" % (
+ self.request.host,
+ self.request.method,
+ self.request.path,
+ )
+
+ # Store the key and address
+ self.key = key
+ self.address = ipaddress.ip_address(
+ self.request.remote_ip,
+ )
+
# What is the current time?
self.now = datetime.datetime.utcnow()
# When to expire?
self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
- self.prefix = "-".join((
- self.__class__.__name__,
- self.request.host,
- self.request.path,
- self.request.method,
- self.request.remote_ip,
- ))
-
async def is_ratelimited(self):
"""
Returns True if the request is prohibited by the rate limiter
"""
- counter = await self.get_counter()
+ requests = await self.get_requests()
# The client is rate-limited when more requests have been
# received than allowed.
- if counter >= self.limit:
+ if requests >= self.limit:
return True
- # Increment the counter
- await self.increment_counter()
+ # Increment the request counter
+ await self.increment_requests()
# If not ratelimited, write some headers
- self.write_headers(counter=counter)
-
- @property
- def key(self):
- return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+ self.write_headers(requests=requests)
- @property
- def keys_to_check(self):
- for minute in range(self.minutes + 1):
- when = self.now - datetime.timedelta(minutes=minute)
-
- yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
-
- async def get_counter(self):
+ async def get_requests(self):
"""
- Returns the number of requests that have been done in
- recent time.
+ Returns the number of requests that have been done in the recent sliding window
"""
- async with await self.backend.cache.pipeline() as p:
- for key in self.keys_to_check:
- await p.get(key)
-
- # Run the pipeline
- res = await p.execute()
-
- # Return the sum
- return sum((int(e) for e in res if e))
-
- def write_headers(self, counter):
+ # Now, rounded down to the minute
+ now = sqlalchemy.func.date_trunc(
+ "minute", sqlalchemy.func.current_timestamp(),
+ )
+
+ # Go back into the past to see when the sliding window has started
+ since = now - datetime.timedelta(minutes=self.minutes)
+
+ # Sum up all requests
+ stmt = (
+ sqlalchemy
+ .select(
+ sqlalchemy.func.sum(
+ ratelimiter.c.requests,
+ ).label("requests")
+ )
+ .where(
+ ratelimiter.c.key == self.key,
+ ratelimiter.c.timestamp >= since,
+ ratelimiter.c.address == self.address,
+ )
+ )
+
+ return await self.db.select_one(stmt, "requests") or 0
+
+ def write_headers(self, requests):
# Send the limit to the user
self.handler.set_header("X-Rate-Limit-Limit", self.limit)
# Send the user how many requests are left for this time window
- self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
+ self.handler.set_header("X-Rate-Limit-Remaining", self.limit - requests)
# Send when the limit resets
self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
- async def increment_counter(self):
- async with await self.backend.cache.pipeline() as p:
- # Increment the key
- await p.incr(self.key)
-
- # Set expiry
- await p.expireat(self.key, self.expires_at)
-
- # Run the pipeline
- await p.execute()
+ async def increment_requests(self):
+ """
+ Increments the counter that identifies this request
+ """
+ now = sqlalchemy.func.date_trunc(
+ "minute", sqlalchemy.func.current_timestamp(),
+ )
+
+ # Figure out until when we will need this entry
+ expires_at = now + datetime.timedelta(minutes=self.minutes + 1)
+
+ # Create a new entry to the database
+ insert_stmt = (
+ sqlalchemy.dialects.postgresql
+ .insert(
+ ratelimiter,
+ )
+ .values({
+ "key" : self.key,
+ "timestamp" : now,
+ "address" : self.address,
+ "requests" : 1,
+ "expires_at" : expires_at,
+ })
+ )
+
+ # If the entry exist already, we just increment the counter
+ upsert_stmt = insert_stmt.on_conflict_do_update(
+ index_elements = [
+ "key", "timestamp", "address",
+ ],
+ set_ = {
+ "requests" : ratelimiter.c.requests + 1
+ },
+ )
+
+ await self.db.execute(upsert_stmt)