git.ipfire.org Git - ipfire.org.git/blob - src/backend/ratelimit.py
Deploy rate-limiting
[ipfire.org.git] / src / backend / ratelimit.py
1 #!/usr/bin/python
2
3 import datetime
4
5 from . import misc
6
class RateLimiter(misc.Object):
    """
    Entry point for rate limiting: every incoming request is wrapped in a
    RateLimiterRequest, which does the counting for that request.
    """

    def handle_request(self, request, handler, minutes, limit):
        """
        Build the rate-limiting state for a single request.

        The new RateLimiterRequest receives this object's backend together
        with the request, its handler, the window size in minutes and the
        maximum number of requests, and is returned to the caller.
        """
        backend = self.backend
        return RateLimiterRequest(
            backend,
            request,
            handler,
            minutes=minutes,
            limit=limit,
        )
11
12
class RateLimiterRequest(misc.Object):
    """
    Rate-limiting bookkeeping for a single request.

    Requests are counted in memcache, one counter per minute. The key of
    every counter is built from the request's host, path, HTTP method and
    the client's IP address, so each such combination is limited on its
    own. A request is refused once the counters of the current minute and
    of the preceding ``minutes`` minutes add up to ``limit``.
    """

    prefix = "ratelimit"

    def init(self, request, handler, minutes, limit):
        """
        Count the request and send the rate-limit headers if it is allowed.

        :param request: incoming request; its host, path, method and
            remote_ip are used to build the counter keys
        :param handler: request handler on which response headers are set
        :param minutes: number of past minutes (in addition to the current
            one) whose counters are summed
        :param limit: number of requests allowed within that window
        """
        self.request = request
        self.handler = handler

        # Save the limits
        self.minutes = minutes
        self.limit = limit

        # Naive datetime in UTC; used for the counter keys and the reset header
        self.now = datetime.datetime.utcnow()

        # Fetch the current counter value from the cache. This happens
        # before the increment below, so self.counter does NOT include the
        # current request.
        self.counter = self.get_counter()

        # Increment the rate-limiting counter
        self.increment_counter()

        # Write the header if we are not limited
        if not self.is_ratelimited():
            self.write_headers()

    def is_ratelimited(self):
        """
        Returns True if the request is prohibited by the rate limiter
        """
        # self.counter excludes the current request, so with a limit of N
        # the first N requests in the window pass and the next is refused.
        return self.counter >= self.limit

    def get_counter(self):
        """
        Returns the number of requests counted in the current minute and the
        preceding self.minutes minutes, or 0 if no counter exists yet.
        """
        keys = self.get_keys_to_check()

        # Presumably only existing keys are returned (missing ones are
        # absent) -- TODO confirm for the memcache client in use.
        res = self.memcache.get_multi(keys)
        if res:
            return sum(int(e) for e in res.values())

        return 0

    def write_headers(self):
        """
        Set the X-Rate-Limit-* response headers on the handler.
        """
        # Send the limit to the user
        self.handler.set_header("X-Rate-Limit-Limit", self.limit)

        # Send the user how many requests are left for this time window.
        # self.counter was read before the current request was counted, so
        # that request has to be subtracted as well; never report < 0.
        remaining = max(self.limit - self.counter - 1, 0)
        self.handler.set_header("X-Rate-Limit-Remaining", remaining)

        # Point in time by which every counter that includes this request
        # has expired, as seconds since the Unix epoch. self.now is a naive
        # UTC datetime: strftime("%s") would interpret it as *local* time
        # (and is a non-portable glibc extension), so the epoch offset is
        # computed explicitly against a naive UTC epoch instead.
        expires = self.now + datetime.timedelta(seconds=self.expires_after)
        delta = expires - datetime.datetime(1970, 1, 1)
        reset = delta.days * 86400 + delta.seconds
        self.handler.set_header("X-Rate-Limit-Reset", str(reset))

    def _make_key(self, key_prefix, when):
        """Return the cache key of the one-minute bucket containing *when*."""
        return "%s-%s" % (key_prefix, when.strftime("%Y-%m-%d-%H:%M"))

    def get_key(self):
        """
        Returns the cache key of the counter for the current minute.
        """
        return self._make_key(self.get_key_prefix(), self.now)

    def get_keys_to_check(self):
        """
        Returns the cache keys of the current minute and of the preceding
        self.minutes minutes, newest first.
        """
        key_prefix = self.get_key_prefix()

        return [
            self._make_key(key_prefix, self.now - datetime.timedelta(minutes=minute))
            for minute in range(self.minutes + 1)
        ]

    def get_key_prefix(self):
        """
        Returns the part of the cache key that identifies this kind of
        request: prefix, host, path, method and client IP joined by "-".
        """
        # NOTE(review): request.path is client-controlled and ends up in the
        # memcache key; memcached limits key length (250 bytes) -- confirm
        # the memcache client copes with very long paths.
        return "-".join((self.prefix, self.request.host, self.request.path,
            self.request.method, self.request.remote_ip,))

    def increment_counter(self):
        """
        Count the current request in the counter of the current minute.
        """
        key = self.get_key()

        # Add the key or increment if it already exists. add() is used
        # first so that a new counter gets its expiry time.
        # NOTE(review): if the key expires between a failed add() and the
        # incr(), this request presumably goes uncounted -- confirm.
        if not self.memcache.add(key, "1", self.expires_after):
            self.memcache.incr(key)

    @property
    def expires_after(self):
        """
        Returns the number of seconds after which the counter has reset.

        Used as the expiry time of each per-minute counter and to compute
        the X-Rate-Limit-Reset header.
        """
        return (self.minutes + 1) * 60