import fastapi
import time
+import typing
class DatabaseSessionMiddleware(fastapi.applications.BaseHTTPMiddleware):
async def dispatch(self, request: fastapi.Request, call_next):
response.headers["X-Process-Time"] = "%.2fms" % ((t_end - t_start) * 1000.0)
return response
+
+
+class RateLimiterMiddleware(fastapi.applications.BaseHTTPMiddleware):
+ async def dispatch(self, request: fastapi.Request, call_next):
+ # Fetch the app
+ app: fastapi.FastAPI = request.app
+
+ # Fetch the backend
+ backend = app.state.backend
+
+ # Find the handler
+ handler = self.find_handler(app, request.scope)
+
+ # Fetch the limiter settings
+ limiter = getattr(handler, "_limiter", None)
+
+ # The limiter does not seem to be configured
+ if limiter is None:
+ return await call_next(request)
+
+ # Create the limiter
+ limiter = backend.ratelimiter(request, **limiter)
+
+ # Check if the request should be ratelimited
+ async with backend.db as session:
+ ratelimited = await limiter.is_ratelimited()
+
+ # If the request has been rate-limited,
+ # we will send a response with status code 429.
+ if ratelimited:
+ response = fastapi.responses.JSONResponse(
+ { "error" : "Ratelimit exceeded" }, status_code=429,
+ )
+
+ # Otherwise we process the request as usual
+ else:
+ response = await call_next(request)
+
+ # Write the response headers
+ limiter.write_headers(response)
+
+ return response
+
+ @staticmethod
+ def find_handler(app: fastapi.FastAPI, scope) -> typing.Callable | None:
+ # Process all routes
+ for route in app.routes:
+ # Match the route
+ match, _ = route.matches(scope)
+
+ # If the router is a match, we return the endpoint
+ if match == fastapi.routing.Match.FULL and hasattr(route, "endpoint"):
+ return route.endpoint
--- /dev/null
+###############################################################################
+# #
+# dbl - A Domain Blocklist Compositor For IPFire #
+# Copyright (C) 2026 IPFire Development Team #
+# #
+# This program is free software: you can redistribute it and/or modify #
+# it under the terms of the GNU General Public License as published by #
+# the Free Software Foundation, either version 3 of the License, or #
+# (at your option) any later version. #
+# #
+# This program is distributed in the hope that it will be useful, #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
+# GNU General Public License for more details. #
+# #
+# You should have received a copy of the GNU General Public License #
+# along with this program. If not, see <http://www.gnu.org/licenses/>. #
+# #
+###############################################################################
+
+import datetime
+import fastapi
+import sqlmodel
+
+import sqlalchemy
+from sqlalchemy import Column, DateTime, Integer, Text, UniqueConstraint
+from sqlalchemy.dialects.postgresql import INET
+
+
+#from . import base
+#from . import database
+
+ratelimiter = sqlalchemy.Table(
+ "ratelimiter", sqlmodel.SQLModel.metadata,
+
+ # Key
+ Column("key", Text, nullable=False),
+
+ # Timestamp
+ Column("timestamp", DateTime(timezone=False), nullable=False,
+ server_default=sqlalchemy.func.current_timestamp()),
+
+ # Bucket
+ Column("bucket", Text, nullable=False),
+
+ # Requests
+ Column("requests", Integer, nullable=False, default=1),
+
+ # Expires At
+ Column("expires_at", DateTime(timezone=False), nullable=False),
+
+ # Unique constraint
+ UniqueConstraint("key", "timestamp", "bucket", name="ratelimiter_unique")
+)
+
+class RateLimiter(object):
+ def __init__(self, backend):
+ self.backend = backend
+
+ def __call__(self, *args, **kwargs):
+ """
+ Launch a new request
+ """
+ return RateLimiterRequest(self.backend, *args, **kwargs)
+
+ async def cleanup(self):
+ """
+ Called to cleanup the ratelimiter from expired entries
+ """
+ # Delete everything that has expired in the past
+ stmt = (
+ ratelimiter
+ .delete()
+ .where(
+ ratelimiter.c.expires_at <= sqlalchemy.func.current_timestamp(),
+ )
+ )
+
+ # Run the query
+ async with await self.backend.db.transaction():
+ await self.backend.db.execute(stmt)
+
+
+class RateLimiterRequest(object):
+ def __init__(self, backend, request: fastapi.Request, *, minutes, limit, key=None):
+ self.backend = backend
+
+ # Store the request
+ self.request: fastapi.Request = request
+
+ # Number of requests in the current window
+ self.requests = None
+
+ # Save the limits
+ self.minutes = minutes
+ self.limit = limit
+
+ # Create a default key if none given
+ if key is None:
+ key = "%s-%s-%s" % (
+ self.request.method,
+ self.request.url.hostname,
+ self.request.url.path,
+ )
+
+ # Store the key and address
+ self.key = key
+ self.address, port = self.request.client
+
+ # What is the current time?
+ self.now = datetime.datetime.utcnow()
+
+ # When to expire?
+ self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
+
+ async def is_ratelimited(self):
+ """
+ Returns True if the request is prohibited by the rate limiter
+ """
+ self.requests = await self.get_requests()
+
+ # The client is rate-limited when more requests have been
+ # received than allowed.
+ if self.requests >= self.limit:
+ return True
+
+ # Increment the request counter
+ await self.increment_requests()
+
+ async def get_requests(self):
+ """
+ Returns the number of requests that have been done in the recent sliding window
+ """
+ # Now, rounded down to the minute
+ now = sqlalchemy.func.date_trunc(
+ "minute", sqlalchemy.func.current_timestamp(),
+ )
+
+ # Go back into the past to see when the sliding window has started
+ since = now - datetime.timedelta(minutes=self.minutes)
+
+ # Sum up all requests
+ stmt = (
+ sqlalchemy
+ .select(
+ sqlalchemy.func.sum(
+ ratelimiter.c.requests,
+ ).label("requests")
+ )
+ .where(
+ ratelimiter.c.key == self.key,
+ ratelimiter.c.timestamp >= since,
+ ratelimiter.c.bucket == "%s" % self.address,
+ )
+ )
+
+ return await self.backend.db.select_one(stmt, "requests") or 0
+
+ def write_headers(self, response: fastapi.Response):
+ # Send the limit to the user
+ response.headers.append("X-Rate-Limit-Limit", "%s" % self.limit)
+
+ # Send the user how many requests are left for this time window
+ response.headers.append(
+ "X-Rate-Limit-Remaining", "%s" % (self.limit - self.requests),
+ )
+
+ # Send when the limit resets
+ response.headers.append(
+ "X-Rate-Limit-Reset", self.expires_at.strftime("%a, %d %b %Y %H:%M:%S GMT"),
+ )
+
+ # Send Retry-After (in seconds)
+ if self.requests >= self.limit:
+ response.headers.append(
+ "Retry-After", "%.0f" % (self.expires_at - self.now).total_seconds(),
+ )
+
+ async def increment_requests(self):
+ """
+ Increments the counter that identifies this request
+ """
+ now = sqlalchemy.func.date_trunc(
+ "minute", sqlalchemy.func.current_timestamp(),
+ )
+
+ # Figure out until when we will need this entry
+ expires_at = now + datetime.timedelta(minutes=self.minutes + 1)
+
+ # Create a new entry to the database
+ insert_stmt = (
+ sqlalchemy.dialects.postgresql
+ .insert(
+ ratelimiter,
+ )
+ .values({
+ "key" : self.key,
+ "timestamp" : now,
+ "bucket" : "%s" % self.address,
+ "requests" : 1,
+ "expires_at" : expires_at,
+ })
+ )
+
+ # If the entry exist already, we just increment the counter
+ upsert_stmt = insert_stmt.on_conflict_do_update(
+ index_elements = [
+ "key", "timestamp", "bucket",
+ ],
+ set_ = {
+ "requests" : ratelimiter.c.requests + 1
+ },
+ )
+
+ await self.backend.db.execute(upsert_stmt)