]>
Commit | Line | Data |
---|---|---|
372ef119 MT |
1 | #!/usr/bin/python |
2 | ||
3 | import datetime | |
4 | ||
5 | from . import misc | |
6 | ||
class RateLimiter(misc.Object):
    """
        Factory for per-request rate limiting state.
    """

    def handle_request(self, request, handler, minutes, limit):
        """
            Returns a RateLimiterRequest bound to this backend for the
            given request and handler, configured with the time window
            (`minutes`) and the maximum number of requests (`limit`).
        """
        backend = self.backend

        return RateLimiterRequest(
            backend, request, handler, limit=limit, minutes=minutes,
        )
11 | ||
12 | ||
class RateLimiterRequest(misc.Object):
    """
        Rate limiting state for a single incoming request.

        Requests are counted in one cache bucket per minute. The bucket key
        is built from the class name, host, path, method and remote IP.
        A request is rejected when the sum of the current bucket and the
        `minutes` preceding buckets has reached `limit`.
    """

    def init(self, request, handler, minutes, limit):
        self.request = request
        self.handler = handler

        # Save the limits
        self.minutes = minutes
        self.limit = limit

        # What is the current time? This is a timezone-aware UTC datetime,
        # so converting it to a Unix timestamp does not depend on the
        # server's local time zone (naive utcnow() values would).
        self.now = datetime.datetime.now(datetime.timezone.utc)

        # When to expire?
        self.expires_at = self.now + datetime.timedelta(minutes=self.minutes + 1)

        self.prefix = "-".join((
            self.__class__.__name__,
            self.request.host,
            self.request.path,
            self.request.method,
            self.request.remote_ip,
        ))

    async def is_ratelimited(self):
        """
            Returns True if the request is prohibited by the rate limiter,
            False otherwise.

            If the request is allowed, it is counted and the
            X-Rate-Limit-* headers are written to the handler.
        """
        counter = await self.get_counter()

        # The client is rate-limited when more requests have been
        # received than allowed.
        if counter >= self.limit:
            return True

        # Increment the counter
        await self.increment_counter()

        # If not ratelimited, write some headers
        self.write_headers(counter=counter)

        return False

    def _make_key(self, when):
        """
            Returns the cache key of the per-minute bucket containing `when`.
        """
        return "%s-%s" % (self.prefix, when.strftime("%Y-%m-%d-%H:%M"))

    @property
    def key(self):
        """
            The cache key of the bucket for the current minute.
        """
        return self._make_key(self.now)

    @property
    def keys_to_check(self):
        """
            Yields the keys of the current bucket and of the `minutes`
            buckets before it.
        """
        for minute in range(self.minutes + 1):
            yield self._make_key(self.now - datetime.timedelta(minutes=minute))

    @property
    def expires_at_timestamp(self):
        """
            expires_at as an integer Unix timestamp.
        """
        return int(self.expires_at.timestamp())

    async def get_counter(self):
        """
            Returns the number of requests that have been done in
            recent time.
        """
        async with await self.backend.cache.pipeline() as p:
            for key in self.keys_to_check:
                await p.get(key)

            # Run the pipeline
            res = await p.execute()

        # Missing buckets come back empty and are skipped
        return sum(int(e) for e in res if e)

    def write_headers(self, counter):
        """
            Writes the X-Rate-Limit-* headers to the handler.

            `counter` is the number of requests found in the window
            *before* the current request was counted.
        """
        # Send the limit to the user
        self.handler.set_header("X-Rate-Limit-Limit", self.limit)

        # Send the user how many requests are left for this time window.
        # The current request uses up one of them, so it is subtracted too.
        self.handler.set_header("X-Rate-Limit-Remaining", max(self.limit - counter - 1, 0))

        # Send when the limit resets, as a Unix timestamp. This replaces
        # strftime("%s"), which is platform-specific and interprets
        # datetimes in local time.
        self.handler.set_header("X-Rate-Limit-Reset", self.expires_at_timestamp)

    async def increment_counter(self):
        """
            Counts the current request in this minute's bucket and sets the
            bucket's expiry time.
        """
        async with await self.backend.cache.pipeline() as p:
            # Increment the key
            await p.incr(self.key)

            # Set expiry. Pass an explicit Unix timestamp rather than a
            # datetime, which the cache client may convert using the
            # server's local time zone.
            await p.expireat(self.key, self.expires_at_timestamp)

            # Run the pipeline
            await p.execute()