]> git.ipfire.org Git - dbl.git/commitdiff
api: Implement a basic rate limiter
authorMichael Tremer <michael.tremer@ipfire.org>
Tue, 3 Mar 2026 16:23:38 +0000 (16:23 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Tue, 3 Mar 2026 16:24:31 +0000 (16:24 +0000)
This code is borrowed from Pakfire.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/database.sql
src/dbl/__init__.py
src/dbl/api/__init__.py
src/dbl/api/middlewares.py
src/dbl/api/reports.py
src/dbl/database.py
src/dbl/ratelimiter.py [new file with mode: 0644]

index 7aa7bf80b8798f751479ca393d7c0ea438685208..3974da7a723eccb496cb09b29bbfeaaec87d8458 100644 (file)
@@ -61,6 +61,7 @@ dist_pkgpython_PYTHON = \
        src/dbl/i18n.py \
        src/dbl/lists.py \
        src/dbl/logger.py \
+       src/dbl/ratelimiter.py \
        src/dbl/reports.py \
        src/dbl/sources.py \
        src/dbl/users.py \
index 6e1b2108d3b6bf781da097a2cc1aa3eba2aba6a7..ff5f411c155ada60270a73648cee0cb791844fc2 100644 (file)
@@ -2,7 +2,7 @@
 -- PostgreSQL database dump
 --
 
-\restrict 3D6w6Ba6aGIuuPKNG7TDbNdKg6QKBWtkvOqI9CgPZ9Rmbc8lGYUpRjUvPeflzFA
+\restrict eKEDDgUfPNvnW4rVUzYeVjaWYWpaanMRnCSzmfxb3B7N83rA9qmJYZl5gCT8wnK
 
 -- Dumped from database version 17.7 (Debian 17.7-0+deb13u1)
 -- Dumped by pg_dump version 17.7 (Debian 17.7-0+deb13u1)
@@ -271,6 +271,19 @@ CREATE SEQUENCE public.nameservers_id_seq
 ALTER SEQUENCE public.nameservers_id_seq OWNED BY public.nameservers.id;
 
 
+--
+-- Name: ratelimiter; Type: TABLE; Schema: public; Owner: -
+--
+
+CREATE TABLE public.ratelimiter (
+    key text NOT NULL,
+    "timestamp" timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
+    bucket text NOT NULL,
+    requests integer DEFAULT 1 NOT NULL,
+    expires_at timestamp with time zone NOT NULL
+);
+
+
 --
 -- Name: report_comments; Type: TABLE; Schema: public; Owner: -
 --
@@ -470,6 +483,14 @@ ALTER TABLE ONLY public.nameservers
     ADD CONSTRAINT nameservers_pkey PRIMARY KEY (id);
 
 
+--
+-- Name: ratelimiter ratelimiter_unique; Type: CONSTRAINT; Schema: public; Owner: -
+--
+
+ALTER TABLE ONLY public.ratelimiter
+    ADD CONSTRAINT ratelimiter_unique UNIQUE (key, "timestamp", bucket);
+
+
 --
 -- Name: report_comments report_comments_pkey; Type: CONSTRAINT; Schema: public; Owner: -
 --
@@ -683,5 +704,5 @@ ALTER TABLE ONLY public.sources
 -- PostgreSQL database dump complete
 --
 
-\unrestrict 3D6w6Ba6aGIuuPKNG7TDbNdKg6QKBWtkvOqI9CgPZ9Rmbc8lGYUpRjUvPeflzFA
+\unrestrict eKEDDgUfPNvnW4rVUzYeVjaWYWpaanMRnCSzmfxb3B7N83rA9qmJYZl5gCT8wnK
 
index 43ff8ea041261ead86b299181af0cef4da8b8bb4..a6db0d3afe51394d1218f842077e107f8a2f9bef 100644 (file)
@@ -41,6 +41,7 @@ from . import auth
 from . import database
 from . import domains
 from . import lists
+from . import ratelimiter
 from . import reports
 from . import sources
 from . import users
@@ -148,6 +149,10 @@ class Backend(object):
        def lists(self):
                return lists.Lists(self)
 
	@functools.cached_property
	def ratelimiter(self):
		# Lazily create (and cache) the rate limiter frontend
		return ratelimiter.RateLimiter(self)
+
        @functools.cached_property
        def reports(self):
                return reports.Reports(self)
index ced0273688b71ab21136972748b7b9573f0e15fa..b3aa397739ec336e0e7f00cee00559d8f153dde9 100644 (file)
@@ -41,6 +41,7 @@ app = fastapi.FastAPI(
        debug = True,
 )
 app.add_middleware(middlewares.DatabaseSessionMiddleware)
+app.add_middleware(middlewares.RateLimiterMiddleware)
 app.add_middleware(middlewares.ProcessingTimeMiddleware)
 
 # Initialize the backend
@@ -96,6 +97,23 @@ async def require_current_user(
 
        return user
 
def limit(*, minutes, limit, key=None):
	"""
		Marks a request handler as rate-limited.

		At most "limit" requests are allowed within a sliding window of
		"minutes" minutes (not per minute). An optional "key" may be given
		to group several handlers into the same bucket; by default a key
		is derived from the request method, host and path.
	"""
	limiter = {
		"minutes" : minutes,
		"limit"   : limit,
		"key"     : key,
	}

	# Store the configuration with the handler so that the
	# RateLimiterMiddleware can pick it up later
	def decorator(handler):
		setattr(handler, "_limiter", limiter)
		return handler

	return decorator
+
 # Import any endpoints
 from . import domains
 from . import lists
index 666acee091f36c59c2a7740794d6afe3bd6899c9..73b331f84af70bc82e08704d9669cee9ca7d5518 100644 (file)
@@ -20,6 +20,7 @@
 
 import fastapi
 import time
+import typing
 
 class DatabaseSessionMiddleware(fastapi.applications.BaseHTTPMiddleware):
        async def dispatch(self, request: fastapi.Request, call_next):
@@ -49,3 +50,56 @@ class ProcessingTimeMiddleware(fastapi.applications.BaseHTTPMiddleware):
                response.headers["X-Process-Time"] = "%.2fms" % ((t_end - t_start) * 1000.0)
 
                return response
+
+
class RateLimiterMiddleware(fastapi.applications.BaseHTTPMiddleware):
	"""
		Applies the rate limiter configuration that the @limit decorator
		stored on a request handler (as its "_limiter" attribute).

		Requests to handlers without a limiter configuration pass through
		untouched; rate-limited requests receive a 429 response.
	"""
	async def dispatch(self, request: fastapi.Request, call_next):
		# Fetch the app
		app: fastapi.FastAPI = request.app

		# Fetch the backend
		backend = app.state.backend

		# Find the handler (None if no route matches this request)
		handler = self.find_handler(app, request.scope)

		# Fetch the limiter settings
		limiter = getattr(handler, "_limiter", None)

		# The limiter does not seem to be configured
		if limiter is None:
			return await call_next(request)

		# Create the limiter (a per-request RateLimiterRequest object)
		limiter = backend.ratelimiter(request, **limiter)

		# Check if the request should be ratelimited
		# (the context manager only scopes the database session for
		# the check; "session" itself is not used here)
		async with backend.db as session:
			ratelimited = await limiter.is_ratelimited()

		# If the request has been rate-limited,
		# we will send a response with status code 429.
		if ratelimited:
			response = fastapi.responses.JSONResponse(
				{ "error" : "Ratelimit exceeded" }, status_code=429,
			)

		# Otherwise we process the request as usual
		else:
			response = await call_next(request)

		# Write the response headers (X-Rate-Limit-*, Retry-After)
		limiter.write_headers(response)

		return response

	@staticmethod
	def find_handler(app: fastapi.FastAPI, scope) -> typing.Callable | None:
		"""
			Returns the endpoint function of the route that fully matches
			the given ASGI scope, or None if there is no full match.
		"""
		# Process all routes
		for route in app.routes:
			# Match the route
			match, _ = route.matches(scope)

			# If the router is a match, we return the endpoint
			# (partial matches and routes without an endpoint, e.g.
			# mounts, are skipped; falls through to return None)
			if match == fastapi.routing.Match.FULL and hasattr(route, "endpoint"):
				return route.endpoint
index 3fb57d1c55c7051a05822be98e917739349ffad0..b7d5d670b25b618c852fdf6613b7637b702a12cb 100644 (file)
@@ -29,9 +29,10 @@ from .. import reports
 from .. import users
 
 # Import the main app
-from . import require_current_user
 from . import app
 from . import backend
+from . import limit
+from . import require_current_user
 
 class CreateReport(pydantic.BaseModel):
        # List
@@ -190,6 +191,7 @@ async def submit_comment(
        RSS
 """
 @router.get(".rss")
+@limit(limit=25, minutes=60)
 async def rss(
                open: bool | None = None,
                limit: int = 100,
index e86045fa6b54cb4c6d02309f5a78a1fe2b28b2fc..22e5d9ada671adc8ca02610728757cfa8948b8d3 100644 (file)
@@ -200,6 +200,29 @@ class Database(object):
 
                        yield row(**object)
 
	async def select_one(self, stmt, attr=None):
		"""
			Executes the statement and returns at most one row.

			Returns None if the statement did not return any rows.
			If "attr" is given, only that attribute of the row is
			returned instead of the whole row mapping.
		"""
		result = await self.execute(stmt)

		# Process mappings
		result = result.mappings()

		# Extract exactly one result (or none)
		# (raises if the statement returned more than one row)
		result = result.one_or_none()

		# Return if we have no result
		if result is None:
			return

		# Return the whole result if no attribute was requested
		elif attr is None:
			return result

		# Otherwise return the attribute only
		# NOTE(review): SQLAlchemy RowMapping objects are dict-like;
		# confirm that attribute access works here — otherwise this
		# should probably be result[attr]
		return getattr(result, attr)
+
        async def fetch(self, stmt, batch_size=128):
                """
                        Fetches objects of the given type
diff --git a/src/dbl/ratelimiter.py b/src/dbl/ratelimiter.py
new file mode 100644 (file)
index 0000000..0b56d48
--- /dev/null
@@ -0,0 +1,215 @@
+###############################################################################
+#                                                                             #
+# dbl - A Domain Blocklist Compositor For IPFire                              #
+# Copyright (C) 2026 IPFire Development Team                                  #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
+
+import datetime
+import fastapi
+import sqlmodel
+
+import sqlalchemy
+from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint
+from sqlalchemy.dialects.postgresql import INET
+
+
+#from . import base
+#from . import database
+
# Table that backs the rate limiter. The schema (see database.sql) stores
# both timestamps as TIMESTAMP WITH TIME ZONE, so the columns must be
# declared with timezone=True here to match.
ratelimiter = sqlalchemy.Table(
	"ratelimiter", sqlmodel.SQLModel.metadata,

	# Key
	Column("key", Text, nullable=False),

	# Timestamp (start of the per-minute bucket)
	Column("timestamp", DateTime(timezone=True), nullable=False,
		server_default=sqlalchemy.func.current_timestamp()),

	# Bucket (the client address as text)
	Column("bucket", Text, nullable=False),

	# Requests (counter, starts at one)
	Column("requests", Integer, nullable=False, default=1),

	# Expires At (when this row may be cleaned up)
	Column("expires_at", DateTime(timezone=True), nullable=False),

	# Unique constraint (matches ratelimiter_unique in database.sql,
	# used as the conflict target for the upsert)
	UniqueConstraint("key", "timestamp", "bucket", name="ratelimiter_unique")
)
+
class RateLimiter(object):
	"""
		Factory for per-request rate limiter objects.
	"""
	def __init__(self, backend):
		self.backend = backend

	def __call__(self, *args, **kwargs):
		"""
			Starts tracking a new request
		"""
		return RateLimiterRequest(self.backend, *args, **kwargs)

	async def cleanup(self):
		"""
			Removes any expired entries from the ratelimiter table
		"""
		# Anything whose expiry time lies in the past can go away
		stmt = ratelimiter.delete().where(
			ratelimiter.c.expires_at <= sqlalchemy.func.current_timestamp(),
		)

		# Perform the deletion in its own transaction
		async with await self.backend.db.transaction():
			await self.backend.db.execute(stmt)
+
+
class RateLimiterRequest(object):
	"""
		Tracks a single incoming request against the rate limiter.

		The limiter uses a sliding window: requests are counted in
		per-minute buckets in the database; all buckets of the past
		"minutes" minutes are summed up and compared against "limit".
	"""
	def __init__(self, backend, request: fastapi.Request, *, minutes, limit, key=None):
		self.backend = backend

		# Store the request
		self.request: fastapi.Request = request

		# Number of requests in the current window
		# (populated by is_ratelimited())
		self.requests = None

		# Save the limits
		self.minutes = minutes
		self.limit   = limit

		# Create a default key if none given
		if key is None:
			key = "%s-%s-%s" % (
				self.request.method,
				self.request.url.hostname,
				self.request.url.path,
			)

		# Store the key
		self.key = key

		# Store the client address. The client may be unknown (None),
		# e.g. when running behind certain test clients, in which case
		# all such requests share one bucket.
		client = self.request.client
		self.address = client[0] if client else None

		# What is the current time? Timezone-aware, since the database
		# stores timestamps with time zone (utcnow() is deprecated and
		# returns a naive datetime).
		self.now = datetime.datetime.now(datetime.timezone.utc)

		# When to expire? One extra minute so the last (partial) bucket
		# of the window is kept around long enough.
		self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)

	async def is_ratelimited(self):
		"""
			Returns True if the request is prohibited by the rate limiter
		"""
		self.requests = await self.get_requests()

		# The client is rate-limited when more requests have been
		# received than allowed.
		if self.requests >= self.limit:
			return True

		# Increment the request counter
		await self.increment_requests()

		# Explicitly signal that the request may pass
		return False

	async def get_requests(self):
		"""
			Returns the number of requests that have been done in the recent sliding window
		"""
		# Now, rounded down to the minute (computed on the database side)
		now = sqlalchemy.func.date_trunc(
			"minute", sqlalchemy.func.current_timestamp(),
		)

		# Go back into the past to see when the sliding window has started
		since = now - datetime.timedelta(minutes=self.minutes)

		# Sum up all requests of this key/bucket within the window
		stmt = (
			sqlalchemy
			.select(
				sqlalchemy.func.sum(
					ratelimiter.c.requests,
				).label("requests")
			)
			.where(
				ratelimiter.c.key        == self.key,
				ratelimiter.c.timestamp  >= since,
				ratelimiter.c.bucket     == "%s" % self.address,
			)
		)

		# SUM() returns NULL when there are no rows, hence the fallback
		return await self.backend.db.select_one(stmt, "requests") or 0

	def write_headers(self, response: fastapi.Response):
		"""
			Writes the X-Rate-Limit-* (and possibly Retry-After) headers
			to the given response
		"""
		# Send the limit to the user
		response.headers.append("X-Rate-Limit-Limit", "%s" % self.limit)

		# Send the user how many requests are left for this time window
		# (clamped to zero so it never becomes negative once exceeded)
		response.headers.append(
			"X-Rate-Limit-Remaining", "%s" % max(self.limit - self.requests, 0),
		)

		# Send when the limit resets
		response.headers.append(
			"X-Rate-Limit-Reset", self.expires_at.strftime("%a, %d %b %Y %H:%M:%S GMT"),
		)

		# Send Retry-After (in seconds) when the client has been limited
		if self.requests >= self.limit:
			response.headers.append(
				"Retry-After", "%.0f" % (self.expires_at - self.now).total_seconds(),
			)

	async def increment_requests(self):
		"""
			Increments the counter that identifies this request
		"""
		# Now, rounded down to the minute (the bucket timestamp)
		now = sqlalchemy.func.date_trunc(
			"minute", sqlalchemy.func.current_timestamp(),
		)

		# Figure out until when we will need this entry
		expires_at = now + datetime.timedelta(minutes=self.minutes + 1)

		# Create a new entry to the database
		insert_stmt = (
			sqlalchemy.dialects.postgresql
			.insert(
				ratelimiter,
			)
			.values({
				"key"        : self.key,
				"timestamp"  : now,
				"bucket"     : "%s" % self.address,
				"requests"   : 1,
				"expires_at" : expires_at,
			})
		)

		# If the entry exists already, we just increment the counter
		# (conflict target is the ratelimiter_unique constraint)
		upsert_stmt = insert_stmt.on_conflict_do_update(
			index_elements = [
				"key", "timestamp", "bucket",
			],
			set_ = {
				"requests" : ratelimiter.c.requests + 1
			},
		)

		await self.backend.db.execute(upsert_stmt)