]> git.ipfire.org Git - pbs.git/commitdiff
api: Implement a rate limiter for some API requests
authorMichael Tremer <michael.tremer@ipfire.org>
Thu, 10 Jul 2025 15:04:33 +0000 (15:04 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Thu, 10 Jul 2025 15:04:33 +0000 (15:04 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/api/__init__.py
src/api/debuginfo.py
src/api/downloads.py
src/api/limiter.py [new file with mode: 0644]
src/api/packages.py
src/buildservice/ratelimiter.py

index 83ac3a6021ff895d1d2fa195a315c291bbe2ad4f..d0c9906067d0cf5e4309fbde6992af6326a172f1 100644 (file)
@@ -128,6 +128,7 @@ api_PYTHON = \
        src/api/distros.py \
        src/api/downloads.py \
        src/api/events.py \
+       src/api/limiter.py \
        src/api/mirrors.py \
        src/api/packages.py \
        src/api/uploads.py \
index eb0a3a703f4d6320726102d3b3c737a1c3b4bd6e..680a2ffc65a53116338d3ada5ce015763b7eb683 100644 (file)
@@ -28,6 +28,8 @@ from .. import Backend
 backend = Backend("/etc/pakfire/pbs.conf")
 #backend.launch_background_tasks()
 
+from . import limiter
+
 # Initialize the app
 app = fastapi.FastAPI(
        title = "Pakfire Build Service API",
@@ -59,6 +61,9 @@ async def add_process_time_header(request: fastapi.Request, call_next):
 
        return response
 
+# Add a rate limiter
+app.add_middleware(limiter.RateLimiterMiddleware)
+
 # Add CORS
 app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
index a4c126644d0cf1905cd5a8b8c1043e7b53e14bd2..8badcc9eebba919bc14d72934929efbfa31c5c22 100644 (file)
@@ -22,10 +22,10 @@ import fastapi
 
 from . import app
 from . import backend
-
-# XXX This endpoint need some ratelimiting applied
+from . import limiter
 
 @app.get("/buildid/{buildid}/debuginfo", include_in_schema=False)
+@limiter.limit(limit=100, minutes=10)
 async def get(buildid: str) -> fastapi.responses.StreamingResponse:
        # Fetch the package
        package = await backend.packages.get_by_buildid(buildid)
index d001c1dbf0ff674b49fbdb16985b739bcedd8d6a..9cc5d7f68b6f44c9b259389da3d9ae7b0a5a5b57 100644 (file)
@@ -25,6 +25,7 @@ import stat
 
 from . import app
 from . import backend
+from . import limiter
 from . import util
 
 # Create a new router for all endpoints
@@ -32,9 +33,8 @@ router = fastapi.APIRouter(
        prefix="/downloads",
 )
 
-# XXX These endpoints need some ratelimiting applied
-
 @router.head("/{path:path}", include_in_schema=False)
+@limiter.limit(limit=100, minutes=60, key="downloads")
 async def head(path: str) -> fastapi.Response:
        """
                Handle any HEAD requests
@@ -60,6 +60,7 @@ async def head(path: str) -> fastapi.Response:
        )
 
 @router.get("/{path:path}", include_in_schema=False)
+@limiter.limit(limit=100, minutes=60, key="downloads")
 async def get(
        path: str,
        current_address: ipaddress.IPv6Address | ipaddress.IPv4Address = \
diff --git a/src/api/limiter.py b/src/api/limiter.py
new file mode 100644 (file)
index 0000000..2e67866
--- /dev/null
@@ -0,0 +1,92 @@
+###############################################################################
+#                                                                             #
+# Pakfire - The IPFire package management system                              #
+# Copyright (C) 2025 Pakfire development team                                 #
+#                                                                             #
+# This program is free software: you can redistribute it and/or modify        #
+# it under the terms of the GNU General Public License as published by        #
+# the Free Software Foundation, either version 3 of the License, or           #
+# (at your option) any later version.                                         #
+#                                                                             #
+# This program is distributed in the hope that it will be useful,             #
+# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
+# GNU General Public License for more details.                                #
+#                                                                             #
+# You should have received a copy of the GNU General Public License           #
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
+#                                                                             #
+###############################################################################
+
+import fastapi
+import typing
+
+def _find_handler(app: fastapi.FastAPI, scope) -> typing.Callable | None:
+       # Process all routes
+       for route in app.routes:
+               # Match the route
+               match, _ = route.matches(scope)
+
+               # If the router is a match, we return the endpoint
+               if match == fastapi.routing.Match.FULL and hasattr(route, "endpoint"):
+                       return route.endpoint
+
+
+class RateLimiterMiddleware(fastapi.applications.BaseHTTPMiddleware):
+       async def dispatch(self, request: fastapi.Request, call_next):
+               # Fetch the app
+               app: fastapi.FastAPI = request.app
+
+               # Fetch the backend
+               backend = app.state.backend
+
+               # Find the handler
+               handler = _find_handler(app, request.scope)
+
+               # Fetch the limiter settings
+               limiter = getattr(handler, "_limiter", None)
+
+               # The limiter does not seem to be configured
+               if limiter is None:
+                       return await call_next(request)
+
+               # Create the limiter
+               limiter = backend.ratelimiter(request, **limiter)
+
+               # Check if the request should be ratelimited
+               async with backend.db as session:
+                       ratelimited = await limiter.is_ratelimited()
+
+               # If the request has been rate-limited,
+               # we will send a response with status code 429.
+               if ratelimited:
+                       response = fastapi.responses.JSONResponse(
+                               { "error" : "Ratelimit exceeded" }, status_code=429,
+                       )
+
+               # Otherwise we process the request as usual
+               else:
+                       response = await call_next(request)
+
+               # Write the response headers
+               limiter.write_headers(response)
+
+               return response
+
+
+def limit(*, minutes, limit, key=None):
+       """
+               This decorator takes a limit of how many requests per minute can be processed.
+       """
+       limiter = {
+               "minutes" : minutes,
+               "limit"   : limit,
+               "key"     : key,
+       }
+
+       # Store the configuration with the handler
+       def decorator(handler):
+               setattr(handler, "_limiter", limiter)
+               return handler
+
+       return decorator
index 6f8f06714a2e3b4aac7299becdb7ec5249fa887b..456a4d3b2cc50e28ea501e02d69ea79d059f5286 100644 (file)
@@ -27,6 +27,7 @@ from uuid import UUID
 from . import apiv1
 from . import app
 from . import backend
+from . import limiter
 
 from ..packages import Package, File
 
@@ -90,9 +91,8 @@ async def get_filelist_by_uuid(
        return [file async for file in await package.get_files()]
 
 
-# XXX This endpoint need some ratelimiting applied
-
 @app.get("/packages/{uuid:uuid}/download/{path:path}", include_in_schema=False)
+@limiter.limit(limit=100, minutes=60)
 async def download_file(
        path: str,
        package: Package = fastapi.Depends(get_package_by_uuid),
index cf5a68445573f3edcb84d23c316f31a4373663ae..5e1b4f5fcd30b066e716eb1656609f1a3a8a8725 100644 (file)
@@ -20,6 +20,7 @@
 ###############################################################################
 
 import datetime
+import fastapi
 import ipaddress
 
 import sqlalchemy
@@ -79,9 +80,11 @@ class RateLimiter(base.Object):
 
 
 class RateLimiterRequest(base.Object):
-       def init(self, request, handler, *, minutes, limit, key=None):
-               self.request = request
-               self.handler = handler
+       def init(self, request: fastapi.Request, *, minutes, limit, key=None):
+               self.request: fastapi.Request = request
+
+               # Number of requests in the current window
+               self.requests = None
 
                # Save the limits
                self.minutes = minutes
@@ -90,16 +93,14 @@ class RateLimiterRequest(base.Object):
                # Create a default key if none given
                if key is None:
                        key = "%s-%s-%s" % (
-                               self.request.host,
                                self.request.method,
-                               self.request.path,
+                               self.request.url.hostname,
+                               self.request.url.path,
                        )
 
                # Store the key and address
-               self.key        = key
-               self.address    = ipaddress.ip_address(
-                       self.request.remote_ip,
-               )
+               self.key = key
+               self.address, port = self.request.client
 
                # What is the current time?
                self.now = datetime.datetime.utcnow()
@@ -111,19 +112,16 @@ class RateLimiterRequest(base.Object):
                """
                        Returns True if the request is prohibited by the rate limiter
                """
-               requests = await self.get_requests()
+               self.requests = await self.get_requests()
 
                # The client is rate-limited when more requests have been
                # received than allowed.
-               if requests >= self.limit:
+               if self.requests >= self.limit:
                        return True
 
                # Increment the request counter
                await self.increment_requests()
 
-               # If not ratelimited, write some headers
-               self.write_headers(requests=requests)
-
        async def get_requests(self):
                """
                        Returns the number of requests that have been done in the recent sliding window
@@ -153,15 +151,25 @@ class RateLimiterRequest(base.Object):
 
                return await self.db.select_one(stmt, "requests") or 0
 
-       def write_headers(self, requests):
+       def write_headers(self, response: fastapi.Response):
                # Send the limit to the user
-               self.handler.set_header("X-Rate-Limit-Limit", self.limit)
+               response.headers.append("X-Rate-Limit-Limit", "%s" % self.limit)
 
                # Send the user how many requests are left for this time window
-               self.handler.set_header("X-Rate-Limit-Remaining", self.limit - requests)
+               response.headers.append(
+                       "X-Rate-Limit-Remaining", "%s" % (self.limit - self.requests),
+               )
 
                # Send when the limit resets
-               self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
+               response.headers.append(
+                       "X-Rate-Limit-Reset", self.expires_at.strftime("%a, %d %b %Y %H:%M:%S GMT"),
+               )
+
+               # Send Retry-After (in seconds)
+               if self.requests >= self.limit:
+                       response.headers.append(
+                               "Retry-After", "%.0f" % (self.expires_at - self.now).total_seconds(),
+                       )
 
        async def increment_requests(self):
                """