From: Michael Tremer Date: Thu, 10 Jul 2025 15:04:33 +0000 (+0000) Subject: api: Implement a rate limiter for some API requests X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5345ecb9d06c8073c893de5428fe47b6bbec090d;p=pbs.git api: Implement a rate limiter for some API requests Signed-off-by: Michael Tremer --- diff --git a/Makefile.am b/Makefile.am index 83ac3a60..d0c99060 100644 --- a/Makefile.am +++ b/Makefile.am @@ -128,6 +128,7 @@ api_PYTHON = \ src/api/distros.py \ src/api/downloads.py \ src/api/events.py \ + src/api/limiter.py \ src/api/mirrors.py \ src/api/packages.py \ src/api/uploads.py \ diff --git a/src/api/__init__.py b/src/api/__init__.py index eb0a3a70..680a2ffc 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -28,6 +28,8 @@ from .. import Backend backend = Backend("/etc/pakfire/pbs.conf") #backend.launch_background_tasks() +from . import limiter + # Initialize the app app = fastapi.FastAPI( title = "Pakfire Build Service API", @@ -59,6 +61,9 @@ async def add_process_time_header(request: fastapi.Request, call_next): return response +# Add a rate limiter +app.add_middleware(limiter.RateLimiterMiddleware) + # Add CORS app.add_middleware( fastapi.middleware.cors.CORSMiddleware, diff --git a/src/api/debuginfo.py b/src/api/debuginfo.py index a4c12664..8badcc9e 100644 --- a/src/api/debuginfo.py +++ b/src/api/debuginfo.py @@ -22,10 +22,10 @@ import fastapi from . import app from . import backend - -# XXX This endpoint need some ratelimiting applied +from . import limiter @app.get("/buildid/{buildid}/debuginfo", include_in_schema=False) +@limiter.limit(limit=100, minutes=10) async def get(buildid: str) -> fastapi.responses.StreamingResponse: # Fetch the package package = await backend.packages.get_by_buildid(buildid) diff --git a/src/api/downloads.py b/src/api/downloads.py index d001c1db..9cc5d7f6 100644 --- a/src/api/downloads.py +++ b/src/api/downloads.py @@ -25,6 +25,7 @@ import stat from . import app from . import backend +from . import limiter from . import util # Create a new router for all endpoints @@ -32,9 +33,8 @@ router = fastapi.APIRouter( prefix="/downloads", ) -# XXX These endpoints need some ratelimiting applied - @router.head("/{path:path}", include_in_schema=False) +@limiter.limit(limit=100, minutes=60, key="downloads") async def head(path: str) -> fastapi.Response: """ Handle any HEAD requests @@ -60,6 +60,7 @@ async def head(path: str) -> fastapi.Response: ) @router.get("/{path:path}", include_in_schema=False) +@limiter.limit(limit=100, minutes=60, key="downloads") async def get( path: str, current_address: ipaddress.IPv6Address | ipaddress.IPv4Address = \ diff --git a/src/api/limiter.py b/src/api/limiter.py new file mode 100644 index 00000000..2e678663 --- /dev/null +++ b/src/api/limiter.py @@ -0,0 +1,92 @@ +############################################################################### +# # +# Pakfire - The IPFire package management system # +# Copyright (C) 2025 Pakfire development team # +# # +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the GNU General Public License as published by # +# the Free Software Foundation, either version 3 of the License, or # +# (at your option) any later version. # +# # +# This program is distributed in the hope that it will be useful, # +# but WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # +# GNU General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +############################################################################### + +import fastapi +import typing + +def _find_handler(app: fastapi.FastAPI, scope) -> typing.Callable | None: + # Process all routes + for route in app.routes: + # Match the route + match, _ = route.matches(scope) + + # If the router is a match, we return the endpoint + if match == fastapi.routing.Match.FULL and hasattr(route, "endpoint"): + return route.endpoint + + +class RateLimiterMiddleware(fastapi.applications.BaseHTTPMiddleware): + async def dispatch(self, request: fastapi.Request, call_next): + # Fetch the app + app: fastapi.FastAPI = request.app + + # Fetch the backend + backend = app.state.backend + + # Find the handler + handler = _find_handler(app, request.scope) + + # Fetch the limiter settings + limiter = getattr(handler, "_limiter", None) + + # The limiter does not seem to be configured + if limiter is None: + return await call_next(request) + + # Create the limiter + limiter = backend.ratelimiter(request, **limiter) + + # Check if the request should be ratelimited + async with backend.db as session: + ratelimited = await limiter.is_ratelimited() + + # If the request has been rate-limited, + # we will send a response with status code 429. + if ratelimited: + response = fastapi.responses.JSONResponse( + { "error" : "Ratelimit exceeded" }, status_code=429, + ) + + # Otherwise we process the request as usual + else: + response = await call_next(request) + + # Write the response headers + limiter.write_headers(response) + + return response + + +def limit(*, minutes, limit, key=None): + """ + This decorator takes a limit of how many requests per minute can be processed. + """ + limiter = { + "minutes" : minutes, + "limit" : limit, + "key" : key, + } + + # Store the configuration with the handler + def decorator(handler): + setattr(handler, "_limiter", limiter) + return handler + + return decorator diff --git a/src/api/packages.py b/src/api/packages.py index 6f8f0671..456a4d3b 100644 --- a/src/api/packages.py +++ b/src/api/packages.py @@ -27,6 +27,7 @@ from uuid import UUID from . import apiv1 from . import app from . import backend +from . import limiter from ..packages import Package, File @@ -90,9 +91,8 @@ async def get_filelist_by_uuid( return [file async for file in await package.get_files()] -# XXX This endpoint need some ratelimiting applied - @app.get("/packages/{uuid:uuid}/download/{path:path}", include_in_schema=False) +@limiter.limit(limit=100, minutes=60) async def download_file( path: str, package: Package = fastapi.Depends(get_package_by_uuid), diff --git a/src/buildservice/ratelimiter.py b/src/buildservice/ratelimiter.py index cf5a6844..5e1b4f5f 100644 --- a/src/buildservice/ratelimiter.py +++ b/src/buildservice/ratelimiter.py @@ -20,6 +20,7 @@ ############################################################################### import datetime +import fastapi import ipaddress import sqlalchemy @@ -79,9 +80,11 @@ class RateLimiter(base.Object): class RateLimiterRequest(base.Object): - def init(self, request, handler, *, minutes, limit, key=None): - self.request = request - self.handler = handler + def init(self, request: fastapi.Request, *, minutes, limit, key=None): + self.request: fastapi.Request = request + + # Number of requests in the current window + self.requests = None # Save the limits self.minutes = minutes @@ -90,16 +93,14 @@ class RateLimiterRequest(base.Object): # Create a default key if none given if key is None: key = "%s-%s-%s" % ( - self.request.host, self.request.method, - self.request.path, + self.request.url.hostname, + self.request.url.path, ) # Store the key and address - self.key = key - self.address = ipaddress.ip_address( - self.request.remote_ip, - ) + self.key = key + self.address, port = self.request.client # What is the current time? self.now = datetime.datetime.utcnow() @@ -111,19 +112,16 @@ class RateLimiterRequest(base.Object): """ Returns True if the request is prohibited by the rate limiter """ - requests = await self.get_requests() + self.requests = await self.get_requests() # The client is rate-limited when more requests have been # received than allowed. - if requests >= self.limit: + if self.requests >= self.limit: return True # Increment the request counter await self.increment_requests() - # If not ratelimited, write some headers - self.write_headers(requests=requests) - async def get_requests(self): """ Returns the number of requests that have been done in the recent sliding window @@ -153,15 +151,25 @@ class RateLimiterRequest(base.Object): return await self.db.select_one(stmt, "requests") or 0 - def write_headers(self, requests): + def write_headers(self, response: fastapi.Response): # Send the limit to the user - self.handler.set_header("X-Rate-Limit-Limit", self.limit) + response.headers.append("X-Rate-Limit-Limit", "%s" % self.limit) # Send the user how many requests are left for this time window - self.handler.set_header("X-Rate-Limit-Remaining", self.limit - requests) + response.headers.append( + "X-Rate-Limit-Remaining", "%s" % (self.limit - self.requests), + ) # Send when the limit resets - self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s")) + response.headers.append( + "X-Rate-Limit-Reset", self.expires_at.strftime("%a, %d %b %Y %H:%M:%S GMT"), + ) + + # Send Retry-After (in seconds) + if self.requests >= self.limit: + response.headers.append( + "Retry-After", "%.0f" % (self.expires_at - self.now).total_seconds(), + ) async def increment_requests(self): """