From: Michael Tremer Date: Tue, 3 Mar 2026 16:23:38 +0000 (+0000) Subject: api: Implement a basic rate limiter X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=0de55d705c09fbe4a9ffa4ebfac396329ea606fc;p=dbl.git api: Implement a basic rate limiter This code is borrowed from Pakfire. Signed-off-by: Michael Tremer --- diff --git a/Makefile.am b/Makefile.am index 7aa7bf8..3974da7 100644 --- a/Makefile.am +++ b/Makefile.am @@ -61,6 +61,7 @@ dist_pkgpython_PYTHON = \ src/dbl/i18n.py \ src/dbl/lists.py \ src/dbl/logger.py \ + src/dbl/ratelimiter.py \ src/dbl/reports.py \ src/dbl/sources.py \ src/dbl/users.py \ diff --git a/src/database.sql b/src/database.sql index 6e1b210..ff5f411 100644 --- a/src/database.sql +++ b/src/database.sql @@ -2,7 +2,7 @@ -- PostgreSQL database dump -- -\restrict 3D6w6Ba6aGIuuPKNG7TDbNdKg6QKBWtkvOqI9CgPZ9Rmbc8lGYUpRjUvPeflzFA +\restrict eKEDDgUfPNvnW4rVUzYeVjaWYWpaanMRnCSzmfxb3B7N83rA9qmJYZl5gCT8wnK -- Dumped from database version 17.7 (Debian 17.7-0+deb13u1) -- Dumped by pg_dump version 17.7 (Debian 17.7-0+deb13u1) @@ -271,6 +271,19 @@ CREATE SEQUENCE public.nameservers_id_seq ALTER SEQUENCE public.nameservers_id_seq OWNED BY public.nameservers.id; +-- +-- Name: ratelimiter; Type: TABLE; Schema: public; Owner: - +-- + +CREATE TABLE public.ratelimiter ( + key text NOT NULL, + "timestamp" timestamp with time zone DEFAULT CURRENT_TIMESTAMP NOT NULL, + bucket text NOT NULL, + requests integer DEFAULT 1 NOT NULL, + expires_at timestamp with time zone NOT NULL +); + + -- -- Name: report_comments; Type: TABLE; Schema: public; Owner: - -- @@ -470,6 +483,14 @@ ALTER TABLE ONLY public.nameservers ADD CONSTRAINT nameservers_pkey PRIMARY KEY (id); +-- +-- Name: ratelimiter ratelimiter_unique; Type: CONSTRAINT; Schema: public; Owner: - +-- + +ALTER TABLE ONLY public.ratelimiter + ADD CONSTRAINT ratelimiter_unique UNIQUE (key, "timestamp", bucket); + + -- -- Name: report_comments report_comments_pkey; Type: CONSTRAINT; Schema: public; Owner: 
def limit(*, minutes, limit, key=None):
    """
        Attaches rate-limiter settings to an endpoint handler.

        At most ``limit`` requests are meant to be accepted within a window of
        ``minutes`` minutes. ``key`` optionally names the counter; it is stored
        as given (``None`` when omitted).

        The settings are stored on the handler as the ``_limiter`` attribute
        and the handler itself is returned unchanged, so this decorator has to
        be applied before the handler is registered as a route.
    """
    # Collect the settings once; they are attached as-is to the handler
    settings = dict(minutes=minutes, limit=limit, key=key)

    def decorator(handler):
        # Tag the handler with its configuration and hand it back
        handler._limiter = settings

        return handler

    return decorator
class RateLimiterMiddleware(fastapi.applications.BaseHTTPMiddleware):
    """
        Enforces the rate limit that has been configured on an endpoint.

        Endpoints opt in by carrying a ``_limiter`` attribute which holds the
        keyword arguments that are passed to ``backend.ratelimiter(request, ...)``.
        Requests whose handler does not carry this attribute are passed
        through untouched.
    """

    async def dispatch(self, request: fastapi.Request, call_next):
        """
            Rejects the request with status code 429 if it is rate-limited and
            processes it as usual otherwise. In both cases the limiter gets to
            write its headers to the response.
        """
        # Fetch the app
        app: fastapi.FastAPI = request.app

        # Fetch the backend
        backend = app.state.backend

        # Find the handler
        handler = self.find_handler(app, request.scope)

        # Fetch the limiter settings
        settings = getattr(handler, "_limiter", None)

        # The limiter does not seem to be configured
        if settings is None:
            return await call_next(request)

        # Create the limiter
        limiter = backend.ratelimiter(request, **settings)

        # Check if the request should be ratelimited
        async with backend.db:
            ratelimited = await limiter.is_ratelimited()

        # If the request has been rate-limited,
        # we will send a response with status code 429.
        if ratelimited:
            response = fastapi.responses.JSONResponse(
                { "error" : "Ratelimit exceeded" }, status_code=429,
            )

        # Otherwise we process the request as usual
        else:
            response = await call_next(request)

        # Write the response headers
        limiter.write_headers(response)

        return response

    @staticmethod
    def find_handler(app: fastapi.FastAPI, scope) -> typing.Callable | None:
        """
            Returns the endpoint of the first route that fully matches the
            scope, or None if there is no such route or it has no endpoint.
        """
        # Process all routes
        for route in app.routes:
            # Match the route
            match, _ = route.matches(scope)

            # Stop at the first full match. Starlette's router dispatches to
            # the first fully matching route, so looking any further could
            # return the handler of a route that is never going to be called.
            # Routes without an endpoint (e.g. mounts) carry no limiter settings.
            if match == fastapi.routing.Match.FULL:
                return getattr(route, "endpoint", None)

        # Nothing matched
        return None
###############################################################################
#                                                                             #
# dbl - A Domain Blocklist Compositor For IPFire                              #
# Copyright (C) 2026 IPFire Development Team                                  #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import datetime
import fastapi
import sqlmodel

import sqlalchemy
# Imported explicitly because sqlalchemy.dialects.postgresql.insert() is used below
import sqlalchemy.dialects.postgresql
from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import INET


#from . import base
#from . import database

# One row counts the requests for one key and one bucket (the client address)
# within one minute. The database schema declares both timestamp columns as
# "timestamp with time zone", so they are declared timezone-aware here as well.
ratelimiter = sqlalchemy.Table(
    "ratelimiter", sqlmodel.SQLModel.metadata,

    # Key
    Column("key", Text, nullable=False),

    # Timestamp
    Column("timestamp", DateTime(timezone=True), nullable=False,
        server_default=sqlalchemy.func.current_timestamp()),

    # Bucket
    Column("bucket", Text, nullable=False),

    # Requests
    Column("requests", Integer, nullable=False, default=1),

    # Expires At
    Column("expires_at", DateTime(timezone=True), nullable=False),

    # Unique constraint
    UniqueConstraint("key", "timestamp", "bucket", name="ratelimiter_unique")
)

class RateLimiter(object):
    """
        Entry point to the rate limiter. Calling the instance creates a
        RateLimiterRequest which does the accounting for a single request.
    """

    def __init__(self, backend):
        self.backend = backend

    def __call__(self, *args, **kwargs):
        """
            Launch a new request

            All arguments are passed on to RateLimiterRequest.
        """
        return RateLimiterRequest(self.backend, *args, **kwargs)

    async def cleanup(self):
        """
            Called to cleanup the ratelimiter from expired entries
        """
        # Delete everything that has expired in the past
        stmt = (
            ratelimiter
            .delete()
            .where(
                ratelimiter.c.expires_at <= sqlalchemy.func.current_timestamp(),
            )
        )

        # Run the query
        async with await self.backend.db.transaction():
            await self.backend.db.execute(stmt)


class RateLimiterRequest(object):
    """
        Accounts a single HTTP request against a limit of ``limit`` requests
        within a sliding window of ``minutes`` minutes.

        Requests are counted per key and per client address. If no key is
        given, one is composed of the request method, hostname and path.
    """

    def __init__(self, backend, request: fastapi.Request, *, minutes, limit, key=None):
        self.backend = backend

        # Store the request
        self.request: fastapi.Request = request

        # Number of requests in the current window (set by is_ratelimited())
        self.requests = None

        # Whether this request has been rejected (set by is_ratelimited())
        self.ratelimited = False

        # Save the limits
        self.minutes = minutes
        self.limit = limit

        # Create a default key if none given
        if key is None:
            key = "%s-%s-%s" % (
                self.request.method,
                self.request.url.hostname,
                self.request.url.path,
            )

        # Store the key
        self.key = key

        # Store the address. The request does not always carry a client
        # address, in which case all such requests share one bucket.
        client = self.request.client
        self.address = client.host if client is not None else "unknown"

        # What is the current time? (timezone-aware, in UTC)
        self.now = datetime.datetime.now(datetime.timezone.utc)

        # When to expire?
        self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)

    async def is_ratelimited(self):
        """
            Returns True if the request is prohibited by the rate limiter

            If the request is permitted, it is counted and False is returned.

            NOTE: Counting and incrementing are two separate statements.
        """
        self.requests = await self.get_requests()

        # The client is rate-limited when more requests have been
        # received than allowed.
        if self.requests >= self.limit:
            self.ratelimited = True
            return True

        # Increment the request counter
        await self.increment_requests()

        # Account for this request, so that the headers will tell
        # what is going to be left after it
        self.requests += 1
        self.ratelimited = False

        return False

    async def get_requests(self):
        """
            Returns the number of requests that have been done in the recent sliding window

            The window starts ``minutes`` minutes before the beginning of the
            current minute (database time). Zero is returned if nothing has
            been counted.
        """
        # Now, rounded down to the minute
        now = sqlalchemy.func.date_trunc(
            "minute", sqlalchemy.func.current_timestamp(),
        )

        # Go back into the past to see when the sliding window has started
        since = now - datetime.timedelta(minutes=self.minutes)

        # Sum up all requests
        stmt = (
            sqlalchemy
            .select(
                sqlalchemy.func.sum(
                    ratelimiter.c.requests,
                ).label("requests")
            )
            .where(
                ratelimiter.c.key == self.key,
                ratelimiter.c.timestamp >= since,
                ratelimiter.c.bucket == "%s" % self.address,
            )
        )

        return await self.backend.db.select_one(stmt, "requests") or 0

    def write_headers(self, response: fastapi.Response):
        """
            Appends the rate limit headers to the response

            Retry-After is only sent if the request has been rejected.
        """
        # No requests have been counted if is_ratelimited() has not been called
        requests = self.requests or 0

        # Send the limit to the user
        response.headers.append("X-Rate-Limit-Limit", "%s" % self.limit)

        # Send the user how many requests are left for this time window
        # (this can never be fewer than none)
        response.headers.append(
            "X-Rate-Limit-Remaining", "%s" % max(self.limit - requests, 0),
        )

        # Send when the limit resets
        response.headers.append(
            "X-Rate-Limit-Reset", self.expires_at.strftime("%a, %d %b %Y %H:%M:%S GMT"),
        )

        # Send Retry-After (in seconds)
        if self.ratelimited:
            response.headers.append(
                "Retry-After", "%.0f" % (self.expires_at - self.now).total_seconds(),
            )

    async def increment_requests(self):
        """
            Increments the counter that identifies this request

            There is one counter per key, bucket and minute which is created
            if it does not exist, yet.
        """
        now = sqlalchemy.func.date_trunc(
            "minute", sqlalchemy.func.current_timestamp(),
        )

        # Figure out until when we will need this entry
        expires_at = now + datetime.timedelta(minutes=self.minutes + 1)

        # Create a new entry to the database
        insert_stmt = (
            sqlalchemy.dialects.postgresql
            .insert(
                ratelimiter,
            )
            .values({
                "key"        : self.key,
                "timestamp"  : now,
                "bucket"     : "%s" % self.address,
                "requests"   : 1,
                "expires_at" : expires_at,
            })
        )

        # If the entry exist already, we just increment the counter
        upsert_stmt = insert_stmt.on_conflict_do_update(
            index_elements = [
                "key", "timestamp", "bucket",
            ],
            set_ = {
                "requests" : ratelimiter.c.requests + 1
            },
        )

        await self.backend.db.execute(upsert_stmt)