From 24da53ef3b554b50c6786a974e0119ffdbdeb351 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Tue, 21 Jan 2025 14:27:03 +0000 Subject: [PATCH] ratelimiter: Reimplement using the PostgreSQL database Redis does not seem to be the right choice for this. We can have a fast, unlogged database table for any of this data and we can drop the entire dependency on Redis. Signed-off-by: Michael Tremer --- src/buildservice/__init__.py | 3 + src/buildservice/ratelimiter.py | 186 +++++++++++++++++++++++--------- src/database.sql | 21 ++++ 3 files changed, 158 insertions(+), 52 deletions(-) diff --git a/src/buildservice/__init__.py b/src/buildservice/__init__.py index 5aadd620..0eafdcf8 100644 --- a/src/buildservice/__init__.py +++ b/src/buildservice/__init__.py @@ -544,6 +544,9 @@ class Backend(object): # Messages await self.messages.queue.cleanup() + # Ratelimiter + await self.ratelimiter.cleanup() + # Sessions await self.sessions.cleanup() diff --git a/src/buildservice/ratelimiter.py b/src/buildservice/ratelimiter.py index 5cfd6fa1..965bb840 100644 --- a/src/buildservice/ratelimiter.py +++ b/src/buildservice/ratelimiter.py @@ -20,17 +20,63 @@ ############################################################################### import datetime +import ipaddress + +import sqlalchemy +from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint +from sqlalchemy.dialects.postgresql import INET from . import base +from . 
import database + +ratelimiter = sqlalchemy.Table( + "ratelimiter", database.Base.metadata, + + # Key + Column("key", Text, nullable=False), + + # Timestamp + Column("timestamp", DateTime(timezone=False), nullable=False, + server_default=sqlalchemy.func.current_timestamp()), + + # Address + Column("address", INET, nullable=False), + + # Requests + Column("requests", Integer, nullable=False, default=1), + + # Expires At + Column("expires_at", DateTime(timezone=False), nullable=False), + + # Unique constraint + UniqueConstraint("key", "timestamp", "address", name="ratelimiter_unique") +) + class RateLimiter(base.Object): - def handle_request(self, request, handler, minutes, limit): + def handle_request(self, request, handler, minutes, limit, **kwargs): return RateLimiterRequest(self.backend, request, handler, - minutes=minutes, limit=limit) + minutes=minutes, limit=limit, **kwargs) + + async def cleanup(self): + """ + Deletes all ratelimiter entries whose expires_at timestamp has passed + """ + # Delete everything that has expired in the past + stmt = ( + ratelimiter + .delete() + .where( + ratelimiter.c.expires_at <= sqlalchemy.func.current_timestamp(), + ) + ) + + # Run the query + await self.db.execute(stmt) class RateLimiterRequest(base.Object): - def init(self, request, handler, minutes, limit): + def init(self, request, handler, minutes, limit, key=None): self.request = request self.handler = handler @@ -38,80 +84,116 @@ class RateLimiterRequest(base.Object): self.minutes = minutes self.limit = limit + # Create a default key if none given + if key is None: + key = "%s-%s-%s" % ( + self.request.host, + self.request.method, + self.request.path, + ) + + # Store the key and address + self.key = key + self.address = ipaddress.ip_address( + self.request.remote_ip, + ) + # What is the current time? self.now = datetime.datetime.utcnow() # When to expire? 
self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1) - self.prefix = "-".join(( - self.__class__.__name__, - self.request.host, - self.request.path, - self.request.method, - self.request.remote_ip, - )) - async def is_ratelimited(self): """ Returns True if the request is prohibited by the rate limiter """ - counter = await self.get_counter() + requests = await self.get_requests() # The client is rate-limited when more requests have been # received than allowed. - if counter >= self.limit: + if requests >= self.limit: return True - # Increment the counter - await self.increment_counter() + # Not rate-limited, so count this request + await self.increment_requests() # If not ratelimited, write some headers - self.write_headers(counter=counter) - - @property - def key(self): - return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M")) + self.write_headers(requests=requests) - @property - def keys_to_check(self): - for minute in range(self.minutes + 1): - when = self.now - datetime.timedelta(minutes=minute) - - yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M")) - - async def get_counter(self): + async def get_requests(self): """ - Returns the number of requests that have been done in - recent time. 
+ Returns the number of requests counted for this key and address within the sliding window """ - async with await self.backend.cache.pipeline() as p: - for key in self.keys_to_check: - await p.get(key) - - # Run the pipeline - res = await p.execute() - - # Return the sum - return sum((int(e) for e in res if e)) - - def write_headers(self, counter): + # Now, rounded down to the minute + now = sqlalchemy.func.date_trunc( + "minute", sqlalchemy.func.current_timestamp(), + ) + + # Go back into the past to see when the sliding window has started + since = now - datetime.timedelta(minutes=self.minutes) + + # Sum up all requests for this key and address since then + stmt = ( + sqlalchemy + .select( + sqlalchemy.func.sum( + ratelimiter.c.requests, + ).label("requests") + ) + .where( + ratelimiter.c.key == self.key, + ratelimiter.c.timestamp >= since, + ratelimiter.c.address == self.address, + ) + ) + + return await self.db.select_one(stmt, "requests") or 0 + + def write_headers(self, requests): # Send the limit to the user self.handler.set_header("X-Rate-Limit-Limit", self.limit) # Send the user how many requests are left for this time window - self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter) + self.handler.set_header("X-Rate-Limit-Remaining", self.limit - requests) # Send when the limit resets self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s")) - async def increment_counter(self): - async with await self.backend.cache.pipeline() as p: - # Increment the key - await p.incr(self.key) - - # Set expiry - await p.expireat(self.key, self.expires_at) - - # Run the pipeline - await p.execute() + async def increment_requests(self): + """ + Increments the per-minute counter for this key and address + """ + now = sqlalchemy.func.date_trunc( + "minute", sqlalchemy.func.current_timestamp(), + ) + + # Figure out until when we will need this entry + expires_at = now + datetime.timedelta(minutes=self.minutes + 1) + + # Create a new entry in the database + insert_stmt = ( + 
sqlalchemy.dialects.postgresql + .insert( + ratelimiter, + ) + .values({ + "key" : self.key, + "timestamp" : now, + "address" : self.address, + "requests" : 1, + "expires_at" : expires_at, + }) + ) + + # If the entry exists already, we just increment the counter + upsert_stmt = insert_stmt.on_conflict_do_update( + index_elements = [ + "key", "timestamp", "address", + ], + set_ = { + "requests" : ratelimiter.c.requests + 1 + }, + ) + + await self.db.execute(upsert_stmt) diff --git a/src/database.sql b/src/database.sql index 0186a24b..b7ca0d61 100644 --- a/src/database.sql +++ b/src/database.sql @@ -642,6 +642,19 @@ CREATE SEQUENCE public.packages_id_seq ALTER SEQUENCE public.packages_id_seq OWNED BY public.packages.id; +-- +-- Name: ratelimiter; Type: TABLE; Schema: public; Owner: - +-- + +CREATE UNLOGGED TABLE public.ratelimiter ( + key text NOT NULL, + "timestamp" timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL, + address inet NOT NULL, + requests integer DEFAULT 1 NOT NULL, + expires_at timestamp without time zone NOT NULL +); + + -- -- Name: relation_sizes; Type: VIEW; Schema: public; Owner: - -- @@ -1421,6 +1434,14 @@ ALTER TABLE ONLY public.packages ADD CONSTRAINT packages_pkey PRIMARY KEY (id); +-- +-- Name: ratelimiter ratelimiter_unique; Type: CONSTRAINT; Schema: public; Owner: - +-- + +ALTER TABLE ONLY public.ratelimiter + ADD CONSTRAINT ratelimiter_unique UNIQUE (key, "timestamp", address); + + -- -- Name: release_images release_images_pkey; Type: CONSTRAINT; Schema: public; Owner: - -- -- 2.47.3