]> git.ipfire.org Git - pbs.git/commitdiff
ratelimiter: Reimplement using the PostgreSQL database
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 21 Jan 2025 14:27:03 +0000 (14:27 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 21 Jan 2025 14:27:03 +0000 (14:27 +0000)
Redis does not seem to be the right choice for this. We can use a fast,
unlogged database table for all of this data and drop the entire
dependency on Redis.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/__init__.py
src/buildservice/ratelimiter.py
src/database.sql

index 5aadd6206ee19f3293acfeef9653e9af4fd13975..0eafdcf860962e064f8f5a6ee7c5b853a2dbd6f7 100644 (file)
@@ -544,6 +544,9 @@ class Backend(object):
                # Messages
                await self.messages.queue.cleanup()
 
+               # Ratelimiter
+               await self.ratelimiter.cleanup()
+
                # Sessions
                await self.sessions.cleanup()
 
index 5cfd6fa15e94193d3dc9abb88922c430cff6569f..965bb84073078b97b53fc961c20646a25fb377fb 100644 (file)
 ###############################################################################
 
 import datetime
+import ipaddress
+
+import sqlalchemy
+from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint
+from sqlalchemy.dialects.postgresql import INET
 
 from . import base
+from . import database
+
+ratelimiter = sqlalchemy.Table(
+       "ratelimiter", database.Base.metadata,
+
+       # Key (identifies what is being limited; defaults to "host-method-path" of the request)
+       Column("key", Text, nullable=False),
+
+       # Timestamp (the request code stores the current time truncated to the minute here)
+       Column("timestamp", DateTime(timezone=False), nullable=False,
+               server_default=sqlalchemy.func.current_timestamp()),
+
+       # Address (the remote IP address of the client)
+       Column("address", INET, nullable=False),
+
+       # Requests (number of requests counted for this key, minute and address)
+       Column("requests", Integer, nullable=False, default=1),
+
+       # Expires At (point in time after which the row is no longer needed)
+       Column("expires_at", DateTime(timezone=False), nullable=False),
+
+       # Unique constraint (one row per key, minute and address; used as the upsert conflict target)
+       UniqueConstraint("key", "timestamp", "address", name="ratelimiter_unique")
+)
+
 
 class RateLimiter(base.Object):
-       def handle_request(self, request, handler, minutes, limit):
+       def handle_request(self, request, handler, minutes, limit, **kwargs):
                return RateLimiterRequest(self.backend, request, handler,
-                       minutes=minutes, limit=limit)
+                       minutes=minutes, limit=limit, **kwargs)
+
+       async def cleanup(self):
+               """
+                       Deletes all ratelimiter entries whose expiry time has passed
+               """
+               # Delete everything that has expired in the past
+               stmt = (
+                       ratelimiter
+                       .delete()
+                       .where(
+                               ratelimiter.c.expires_at <= sqlalchemy.func.current_timestamp(),
+                       )
+               )
+
+               # Run the query
+               await self.db.execute(stmt)
 
 
 class RateLimiterRequest(base.Object):
-       def init(self, request, handler, minutes, limit):
+       def init(self, request, handler, minutes, limit, key=None):
                self.request = request
                self.handler = handler
 
@@ -38,80 +84,116 @@ class RateLimiterRequest(base.Object):
                self.minutes = minutes
                self.limit   = limit
 
+               # Create a default key if none given
+               if key is None:
+                       key = "%s-%s-%s" % (
+                               self.request.host,
+                               self.request.method,
+                               self.request.path,
+                       )
+
+               # Store the key and address
+               self.key        = key
+               self.address    = ipaddress.ip_address(
+                       self.request.remote_ip,
+               )
+
                # What is the current time?
                self.now = datetime.datetime.utcnow()
 
                # When to expire?
                self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
 
-               self.prefix = "-".join((
-                       self.__class__.__name__,
-                       self.request.host,
-                       self.request.path,
-                       self.request.method,
-                       self.request.remote_ip,
-               ))
-
        async def is_ratelimited(self):
                """
                        Returns True if the request is prohibited by the rate limiter
                """
-               counter = await self.get_counter()
+               requests = await self.get_requests()
 
                # The client is rate-limited when more requests have been
                # received than allowed.
-               if counter >= self.limit:
+               if requests >= self.limit:
                        return True
 
-               # Increment the counter
-               await self.increment_counter()
+               # Increment the request counter
+               await self.increment_requests()
 
                # If not ratelimited, write some headers
-               self.write_headers(counter=counter)
-
-       @property
-       def key(self):
-               return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
+               self.write_headers(requests=requests)
 
-       @property
-       def keys_to_check(self):
-               for minute in range(self.minutes + 1):
-                       when = self.now - datetime.timedelta(minutes=minute)
-
-                       yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
-
-       async def get_counter(self):
+       async def get_requests(self):
                """
-                       Returns the number of requests that have been done in
-                       recent time.
+                       Returns the number of requests that have been done in the recent sliding window
                """
-               async with await self.backend.cache.pipeline() as p:
-                       for key in self.keys_to_check:
-                               await p.get(key)
-
-                       # Run the pipeline
-                       res = await p.execute()
-
-               # Return the sum
-               return sum((int(e) for e in res if e))
-
-       def write_headers(self, counter):
+               # Now, rounded down to the minute
+               now = sqlalchemy.func.date_trunc(
+                       "minute", sqlalchemy.func.current_timestamp(),
+               )
+
+               # Go back into the past to see when the sliding window has started
+               since = now - datetime.timedelta(minutes=self.minutes)
+
+               # Sum up all requests
+               stmt = (
+                       sqlalchemy
+                       .select(
+                               sqlalchemy.func.sum(
+                                       ratelimiter.c.requests,
+                               ).label("requests")
+                       )
+                       .where(
+                               ratelimiter.c.key        == self.key,
+                               ratelimiter.c.timestamp  >= since,
+                               ratelimiter.c.address    == self.address,
+                       )
+               )
+
+               return await self.db.select_one(stmt, "requests") or 0
+
+       def write_headers(self, requests):
                # Send the limit to the user
                self.handler.set_header("X-Rate-Limit-Limit", self.limit)
 
                # Send the user how many requests are left for this time window
-               self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
+               self.handler.set_header("X-Rate-Limit-Remaining", self.limit - requests)
 
                # Send when the limit resets
                self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
 
-       async def increment_counter(self):
-               async with await self.backend.cache.pipeline() as p:
-                       # Increment the key
-                       await p.incr(self.key)
-
-                       # Set expiry
-                       await p.expireat(self.key, self.expires_at)
-
-                       # Run the pipeline
-                       await p.execute()
+       async def increment_requests(self):
+               """
+                       Increments the per-minute request counter for this key and address
+               """
+               now = sqlalchemy.func.date_trunc(
+                       "minute", sqlalchemy.func.current_timestamp(),
+               )
+
+               # Figure out until when we will need this entry
+               expires_at = now + datetime.timedelta(minutes=self.minutes + 1)
+
+               # Insert a new entry into the database
+               insert_stmt = (
+                       sqlalchemy.dialects.postgresql
+                       .insert(
+                               ratelimiter,
+                       )
+                       .values({
+                               "key"        : self.key,
+                               "timestamp"  : now,
+                               "address"    : self.address,
+                               "requests"   : 1,
+                               "expires_at" : expires_at,
+                       })
+               )
+
+               # If the entry exists already, we just increment the counter
+               upsert_stmt = insert_stmt.on_conflict_do_update(
+                       index_elements = [
+                               "key", "timestamp", "address",
+                       ],
+                       set_ = {
+                               "requests" : ratelimiter.c.requests + 1
+                       },
+               )
+
+               await self.db.execute(upsert_stmt)
index 0186a24bc3cd99d72912d105b338ad32c38538ab..b7ca0d610775f35eac88d7dd981858c54570234d 100644 (file)
@@ -642,6 +642,19 @@ CREATE SEQUENCE public.packages_id_seq
 ALTER SEQUENCE public.packages_id_seq OWNED BY public.packages.id;
 
 
+--
+-- Name: ratelimiter; Type: TABLE; Schema: public; Owner: -
+--
+
+CREATE UNLOGGED TABLE public.ratelimiter (
+    key text NOT NULL,
+    "timestamp" timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
+    address inet NOT NULL,
+    requests integer DEFAULT 1 NOT NULL,
+    expires_at timestamp without time zone NOT NULL
+);
+
+
 --
 -- Name: relation_sizes; Type: VIEW; Schema: public; Owner: -
 --
@@ -1421,6 +1434,14 @@ ALTER TABLE ONLY public.packages
     ADD CONSTRAINT packages_pkey PRIMARY KEY (id);
 
 
+--
+-- Name: ratelimiter ratelimiter_unique; Type: CONSTRAINT; Schema: public; Owner: -
+--
+
+ALTER TABLE ONLY public.ratelimiter
+    ADD CONSTRAINT ratelimiter_unique UNIQUE (key, "timestamp", address);
+
+
 --
 -- Name: release_images release_images_pkey; Type: CONSTRAINT; Schema: public; Owner: -
 --