]> git.ipfire.org Git - ipfire.org.git/blob - src/backend/ratelimit.py
wiki: Only match usernames when a word starts with @
[ipfire.org.git] / src / backend / ratelimit.py
1 #!/usr/bin/python
2
3 import datetime
4
5 from . import misc
6
class RateLimiter(misc.Object):
    """
    Factory for per-request rate limiter checks.
    """

    def handle_request(self, request, handler, minutes, limit):
        """
        Create the RateLimiterRequest that tracks a single request.

        request and handler are passed through unchanged together with
        this object's backend; minutes and limit are forwarded as
        keyword arguments.
        """
        limits = {"minutes": minutes, "limit": limit}

        return RateLimiterRequest(self.backend, request, handler, **limits)
11
12
class RateLimiterRequest(misc.Object):
    """
    Rate limiter state for one incoming request.

    Requests are counted in the cache in one counter per minute. The key
    of each counter is built from the class name, the request's host,
    path, method and remote IP address, and the minute it belongs to.
    A request is rejected when the counters of the current minute and
    the previous ``minutes`` minutes add up to ``limit`` or more.
    """

    def init(self, request, handler, minutes, limit):
        """
        Store the request context and the limits.

        NOTE(review): presumably called by misc.Object's constructor with
        the arguments that follow the backend - confirm against misc.Object.
        """
        self.request = request
        self.handler = handler

        # Save the limits
        self.minutes = minutes
        self.limit = limit

        # What is the current time?
        # Kept as a naive datetime in UTC (what utcnow() returned), but
        # obtained without the deprecated utcnow() call.
        self.now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        # When to expire?
        self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)

        self.prefix = "-".join((
            self.__class__.__name__,
            self.request.host,
            self.request.path,
            self.request.method,
            self.request.remote_ip,
        ))

    async def is_ratelimited(self):
        """
        Returns True if the request is prohibited by the rate limiter,
        and False otherwise.

        A request that is allowed is counted and the rate limit headers
        are written to the handler.
        """
        counter = await self.get_counter()

        # The client is rate-limited when more requests have been
        # received than allowed.
        if counter >= self.limit:
            return True

        # Increment the counter
        await self.increment_counter()

        # If not ratelimited, write some headers. This request has just
        # been counted, so it must be included in what has been used up.
        self.write_headers(counter=counter + 1)

        return False

    def _key_for(self, when):
        """
        Returns the cache key of the counter for the minute of "when".
        """
        return "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))

    @property
    def _expires_at_timestamp(self):
        """
        Returns expires_at as an integer UNIX timestamp.

        expires_at is naive but holds UTC, so the timezone is attached
        explicitly. Otherwise the conversion would treat the value as
        local time and be off by the UTC offset of the host.
        """
        aware = self.expires_at.replace(tzinfo=datetime.timezone.utc)

        return int(aware.timestamp())

    @property
    def key(self):
        """
        The cache key of the counter for the current minute.
        """
        return self._key_for(self.now)

    @property
    def keys_to_check(self):
        """
        Yields the cache keys of the current minute and each of the
        previous "minutes" minutes, newest first.
        """
        for minute in range(self.minutes + 1):
            yield self._key_for(self.now - datetime.timedelta(minutes=minute))

    async def get_counter(self):
        """
        Returns the number of requests that have been done in
        recent time.
        """
        async with await self.backend.cache.pipeline() as p:
            for key in self.keys_to_check:
                await p.get(key)

            # Run the pipeline
            res = await p.execute()

        # Return the sum, skipping counters without a value
        return sum(int(e) for e in res if e)

    def write_headers(self, counter):
        """
        Writes the rate limit headers to the handler.

        counter is the number of requests that have already been used
        up in the current time window.
        """
        # Send the limit to the user
        self.handler.set_header("X-Rate-Limit-Limit", self.limit)

        # Send the user how many requests are left for this time window
        # (never less than zero)
        self.handler.set_header("X-Rate-Limit-Remaining", max(0, self.limit - counter))

        # Send when the limit resets (as a UNIX timestamp)
        self.handler.set_header("X-Rate-Limit-Reset", "%d" % self._expires_at_timestamp)

    async def increment_counter(self):
        """
        Counts this request in the counter of the current minute and
        (re-)sets the time at which that counter expires.
        """
        async with await self.backend.cache.pipeline() as p:
            # Increment the key
            await p.incr(self.key)

            # Set expiry. An integer UNIX timestamp is passed so that the
            # cache client does not have to interpret a naive datetime.
            # NOTE(review): assumes expireat() accepts an integer UNIX
            # timestamp like the Redis EXPIREAT command - confirm for
            # the cache client in use.
            await p.expireat(self.key, self._expires_at_timestamp)

            # Run the pipeline
            await p.execute()