]> git.ipfire.org Git - ipfire.org.git/blame - src/backend/ratelimit.py
wiki: Only match usernames when a word starts with @
[ipfire.org.git] / src / backend / ratelimit.py
CommitLineData
372ef119
MT
1#!/usr/bin/python
2
3import datetime
4
5from . import misc
6
7class RateLimiter(misc.Object):
8 def handle_request(self, request, handler, minutes, limit):
9 return RateLimiterRequest(self.backend, request, handler,
10 minutes=minutes, limit=limit)
11
12
13class RateLimiterRequest(misc.Object):
372ef119
MT
14 def init(self, request, handler, minutes, limit):
15 self.request = request
16 self.handler = handler
17
18 # Save the limits
19 self.minutes = minutes
20 self.limit = limit
21
0056fdbf 22 # What is the current time?
372ef119
MT
23 self.now = datetime.datetime.utcnow()
24
0056fdbf
MT
25 # When to expire?
26 self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)
372ef119 27
0056fdbf
MT
28 self.prefix = "-".join((
29 self.__class__.__name__,
30 self.request.host,
31 self.request.path,
32 self.request.method,
33 self.request.remote_ip,
34 ))
372ef119 35
0056fdbf 36 async def is_ratelimited(self):
372ef119
MT
37 """
38 Returns True if the request is prohibited by the rate limiter
39 """
0056fdbf
MT
40 counter = await self.get_counter()
41
372ef119
MT
42 # The client is rate-limited when more requests have been
43 # received than allowed.
0056fdbf
MT
44 if counter >= self.limit:
45 return True
46
47 # Increment the counter
48 await self.increment_counter()
49
50 # If not ratelimited, write some headers
51 self.write_headers(counter=counter)
52
53 @property
54 def key(self):
55 return "%s-%s" % (self.prefix, self.now.strftime("%Y-%m-%d-%H:%M"))
56
57 @property
58 def keys_to_check(self):
59 for minute in range(self.minutes + 1):
60 when = self.now - datetime.timedelta(minutes=minute)
61
62 yield "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))
372ef119 63
0056fdbf 64 async def get_counter(self):
372ef119
MT
65 """
66 Returns the number of requests that have been done in
67 recent time.
68 """
0056fdbf
MT
69 async with await self.backend.cache.pipeline() as p:
70 for key in self.keys_to_check:
71 await p.get(key)
372ef119 72
0056fdbf
MT
73 # Run the pipeline
74 res = await p.execute()
372ef119 75
0056fdbf
MT
76 # Return the sum
77 return sum((int(e) for e in res if e))
372ef119 78
0056fdbf 79 def write_headers(self, counter):
372ef119
MT
80 # Send the limit to the user
81 self.handler.set_header("X-Rate-Limit-Limit", self.limit)
82
83 # Send the user how many requests are left for this time window
0056fdbf 84 self.handler.set_header("X-Rate-Limit-Remaining", self.limit - counter)
372ef119 85
0056fdbf
MT
86 # Send when the limit resets
87 self.handler.set_header("X-Rate-Limit-Reset", self.expires_at.strftime("%s"))
372ef119 88
0056fdbf
MT
89 async def increment_counter(self):
90 async with await self.backend.cache.pipeline() as p:
91 # Increment the key
92 await p.incr(self.key)
372ef119 93
0056fdbf
MT
94 # Set expiry
95 await p.expireat(self.key, self.expires_at)
372ef119 96
0056fdbf
MT
97 # Run the pipeline
98 await p.execute()