]>
Commit | Line | Data |
---|---|---|
879aa787 MT |
1 | # Written by Bram Cohen |
2 | # see LICENSE.txt for license information | |
3 | ||
4 | import socket | |
5 | from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH | |
6 | try: | |
7 | from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP | |
8 | timemult = 1000 | |
9 | except ImportError: | |
10 | from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP | |
11 | timemult = 1 | |
12 | from time import sleep | |
13 | from clock import clock | |
14 | import sys | |
15 | from random import shuffle, randrange | |
16 | from natpunch import UPnP_open_port, UPnP_close_port | |
17 | # from BT1.StreamCheck import StreamCheck | |
18 | # import inspect | |
# Interpreters older than Python 2.2.1 have no True/False builtins; on those,
# define them as module globals so the rest of this file can use them.
try:
    True
except NameError:
    globals().update({'True': 1, 'False': 0})

# Poll mask for "readable or writable".
# NOTE(review): this name shadows the builtin all() for the whole module; it is
# kept because other code in this file registers sockets with it.
all = POLLOUT | POLLIN

# Message of the socket.error raised when the UPnP port mapping cannot be made.
UPnP_ERROR = "unable to forward port via UPnP"
28 | ||
class SingleSocket:
    """One non-blocking connection owned by a SocketHandler.

    Outgoing data is queued in ``self.buffer`` and sent opportunistically by
    try_write(). The socket's registration in the owner's poll object asks
    for POLLOUT only while unsent data remains.
    """

    def __init__(self, socket_handler, sock, handler, ip = None):
        """Wrap ``sock`` for ``socket_handler``; ``handler`` gets the callbacks.

        The peer address is taken from ``sock.getpeername()``. If that fails
        (typically because the connect is still in progress), ``ip`` is used
        instead, or 'unknown' when ``ip`` is None.
        """
        self.socket_handler = socket_handler
        self.socket = sock
        self.handler = handler
        self.buffer = []             # queued outgoing data, oldest first
        self.last_hit = clock()      # last activity; compared in scan_for_timeouts
        self.fileno = sock.fileno()  # key in socket_handler.single_sockets
        self.connected = False
        self.skipped = 0             # consecutive send attempts without progress
        try:
            self.ip = self.socket.getpeername()[0]
        except socket.error:
            if ip is None:
                self.ip = 'unknown'
            else:
                self.ip = ip

    def get_ip(self, real=False):
        """Return the cached peer address, re-reading it first when ``real``.

        A failed re-read (or an already closed connection) keeps the cached value.
        """
        if real and self.socket is not None:
            try:
                self.ip = self.socket.getpeername()[0]
            except socket.error:
                pass
        return self.ip

    def close(self):
        """Close the socket and detach it from the owning SocketHandler.

        Must be called at most once (asserts the socket is still open). The
        handler is not notified here; SocketHandler._close_socket() does that.
        """
        assert self.socket
        self.connected = False
        sock = self.socket
        self.socket = None
        self.buffer = []
        del self.socket_handler.single_sockets[self.fileno]
        try:
            self.socket_handler.poll.unregister(sock)
        finally:
            # Release the descriptor even if unregistering fails.
            sock.close()

    def shutdown(self, val):
        """Pass ``val`` (a socket.SHUT_* value) to the socket's shutdown()."""
        self.socket.shutdown(val)

    def is_flushed(self):
        """Return True when no outgoing data is queued."""
        return not self.buffer

    def write(self, s):
        """Queue ``s`` for sending.

        A send is attempted immediately only when the queue was empty. Data
        queued behind it is flushed by later try_write() calls, for example
        from SocketHandler.handle_events on POLLOUT.
        """
        assert self.socket is not None
        self.buffer.append(s)
        if len(self.buffer) == 1:
            self.try_write()

    def try_write(self):
        """Send as much queued data as the socket accepts without blocking.

        Does nothing until the connection is marked connected. Two conditions
        mark the connection dead:

        * a send error other than EWOULDBLOCK;
        * three consecutive attempts that make no progress.

        A dead connection is appended to ``socket_handler.dead_from_write``
        for SocketHandler.close_dead() to close. Otherwise the poll
        registration is updated to request POLLOUT only while data remains.
        """
        if not self.connected:
            return
        dead = False
        try:
            while self.buffer:
                buf = self.buffer[0]
                amount = self.socket.send(buf)
                if amount == 0:
                    self.skipped += 1
                    break
                self.skipped = 0
                if amount != len(buf):
                    # Partial send: keep the unsent tail at the head of the queue.
                    self.buffer[0] = buf[amount:]
                    break
                del self.buffer[0]
        except socket.error as e:
            # The errno is args[0]; indexing the exception itself (e[0]) only
            # works on Python 2 and made every error look fatal elsewhere.
            code = e.args[0] if e.args else None
            dead = code != EWOULDBLOCK
            self.skipped += 1
        if self.skipped >= 3:
            dead = True
        if dead:
            self.socket_handler.dead_from_write.append(self)
            return
        if self.buffer:
            self.socket_handler.poll.register(self.socket, all)
        else:
            self.socket_handler.poll.register(self.socket, POLLIN)

    def set_handler(self, handler):
        """Replace the object that receives this connection's callbacks."""
        self.handler = handler
122 | ||
class SocketHandler:
    """Poll-driven manager for listening sockets and non-blocking connections.

    Listening sockets live in ``self.servers`` and connections (SingleSocket
    objects) in ``self.single_sockets``. Both are keyed by file descriptor
    number, which is what poll() reports in its (fd, event) pairs.
    """

    def __init__(self, timeout, ipv6_enable, readsize = 100000):
        """Create an empty handler.

        ``timeout`` is the idle time, in clock() units, after which
        scan_for_timeouts() drops a connection. ``readsize`` is the maximum
        number of bytes read per recv() call in handle_events().
        """
        self.timeout = timeout
        self.ipv6_enable = ipv6_enable
        self.readsize = readsize
        self.poll = poll()
        # {fileno: SingleSocket}
        self.single_sockets = {}
        # Connections whose writes failed; closed by close_dead().
        self.dead_from_write = []
        # New connections are accepted only below this; lowered by do_poll().
        self.max_connects = 1000
        self.port_forwarded = None   # port mapped via UPnP, if any
        # {fileno: listening socket}
        self.servers = {}
        # Defaults so get_stats() works before bind() has succeeded.
        self.interfaces = []
        self.port = None

    def scan_for_timeouts(self):
        """Close every connection whose last activity is older than the timeout."""
        cutoff = clock() - self.timeout
        # Collect first: closing removes entries from single_sockets.
        stale = [s for s in list(self.single_sockets.values())
                 if s.last_hit < cutoff]
        for s in stale:
            # Skip connections already closed, e.g. by a callback in this loop.
            if s.socket is not None:
                self._close_socket(s)

    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
        """Open non-blocking listening sockets on ``port``.

        ``bind`` is an optional comma-separated list of addresses or host
        names. When it is empty, listen on the IPv6 wildcard (if IPv6 is
        enabled) and/or the IPv4 wildcard. ``reuse`` sets SO_REUSEADDR.
        With ``upnp`` true, the port is also forwarded via UPnP.

        Raises socket.error if a socket cannot be opened or the UPnP mapping
        fails. In that case every socket created by this call is unregistered
        from the poll object and closed, and ``servers``/``interfaces`` are
        left empty.
        """
        port = int(port)
        addrinfos = []
        self.servers = {}
        self.interfaces = []
        if bind:
            if self.ipv6_enable:
                socktype = socket.AF_UNSPEC
            else:
                socktype = socket.AF_INET
            for addr in bind.split(','):
                if sys.version_info < (2, 2):
                    # No getaddrinfo() available: treat each entry as IPv4.
                    addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
                else:
                    addrinfos.extend(socket.getaddrinfo(addr, port,
                                                        socktype, socket.SOCK_STREAM))
        else:
            if self.ipv6_enable:
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
            # With ipv6_socket_style == 0 no separate IPv4 socket is opened,
            # presumably because the IPv6 socket is expected to accept IPv4 too.
            if not addrinfos or ipv6_socket_style != 0:
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
        for addrinfo in addrinfos:
            server = None
            try:
                server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
                if reuse:
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.setblocking(0)
                server.bind(addrinfo[4])
                self.servers[server.fileno()] = server
                if bind:
                    self.interfaces.append(server.getsockname()[0])
                server.listen(64)
                self.poll.register(server, POLLIN)
            except socket.error as e:
                blocked = self.ipv6_enable and ipv6_socket_style == 0 and self.servers
                # The failing socket may not be recorded in self.servers yet.
                self._release_servers(server)
                if blocked:
                    raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
                raise socket.error(str(e))
        if not self.servers:
            raise socket.error('unable to open server port')
        if upnp:
            if not UPnP_open_port(port):
                self._release_servers()
                raise socket.error(UPnP_ERROR)
            self.port_forwarded = port
        self.port = port

    def _release_servers(self, extra = None):
        """Unregister and close all listening sockets plus ``extra``; forget them.

        Errors are ignored: this only runs on failure paths, and cleanup must
        not mask the original error.
        """
        socks = list(self.servers.values())
        if extra is not None and extra not in socks:
            socks.append(extra)
        for server in socks:
            try:
                self.poll.unregister(server)
            except Exception:
                pass    # never registered, or poll already dropped it
            try:
                server.close()
            except Exception:
                pass
        self.servers = {}
        self.interfaces = []

    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
                      ipv6_socket_style = 1, upnp = 0, randomizer = False):
        """Bind the first port in [minport, maxport] that bind() accepts.

        Without ``randomizer``, every port is tried in ascending order. With
        it, at most 20 ports are tried in random order. Returns the bound
        port, or raises socket.error with the last failure's message.
        """
        last_error = 'maxport less than minport - no ports to check'
        if maxport - minport < 50 or not randomizer:
            portrange = list(range(minport, maxport + 1))
            if randomizer:
                shuffle(portrange)
                portrange = portrange[:20]  # check a maximum of 20 ports
        else:
            portrange = []
            while len(portrange) < 20:
                listen_port = randrange(minport, maxport + 1)
                if listen_port not in portrange:
                    portrange.append(listen_port)
        for listen_port in portrange:
            try:
                # reuse was previously accepted here but never forwarded.
                self.bind(listen_port, bind, reuse = reuse,
                          ipv6_socket_style = ipv6_socket_style, upnp = upnp)
                return listen_port
            except socket.error as e:
                last_error = e
        raise socket.error(str(last_error))

    def set_handler(self, handler):
        """Set the default handler for accepted and outgoing connections."""
        self.handler = handler

    def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
        """Start a non-blocking connect to socket address ``dns``; return its SingleSocket.

        connect_ex()'s result is not checked. The first poll event on the
        socket marks it connected, and HUP/ERR closes it (see handle_events).
        """
        if handler is None:
            handler = self.handler
        sock = socket.socket(socktype, socket.SOCK_STREAM)
        sock.setblocking(0)
        try:
            sock.connect_ex(dns)
        except socket.error:
            sock.close()    # don't leak the descriptor
            raise
        except Exception as e:
            sock.close()
            raise socket.error(str(e))
        self.poll.register(sock, POLLIN)
        s = SingleSocket(self, sock, handler, dns[0])
        self.single_sockets[sock.fileno()] = s
        return s

    def start_connection(self, dns, handler = None, randomize = False):
        """Connect to ``dns`` = (host, port), trying each resolved address.

        ``randomize`` shuffles the resolved addresses first. Returns the
        SingleSocket of the first address start_connection_raw() accepts.
        Raises socket.error if resolution fails or no address can be tried.
        """
        if handler is None:
            handler = self.handler
        if sys.version_info < (2, 2):
            return self.start_connection_raw(dns, socket.AF_INET, handler)
        if self.ipv6_enable:
            socktype = socket.AF_UNSPEC
        else:
            socktype = socket.AF_INET
        try:
            addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
                                           socktype, socket.SOCK_STREAM)
        except socket.error:
            raise
        except Exception as e:
            raise socket.error(str(e))
        if randomize:
            shuffle(addrinfos)
        for addrinfo in addrinfos:
            try:
                return self.start_connection_raw(addrinfo[4], addrinfo[0], handler)
            except Exception:
                pass    # try the next address
        raise socket.error('unable to connect')

    def _sleep(self):
        sleep(1)

    @staticmethod
    def _error_code(err):
        """Return the errno carried by socket error ``err``, or None.

        Unlike ``code, msg = err`` this works on Python 3 and for errors
        built from a single message argument.
        """
        args = getattr(err, 'args', ())
        if args:
            return args[0]
        return None

    def handle_events(self, events):
        """Dispatch the (fileno, eventmask) pairs returned by do_poll()."""
        for fd, event in events:
            server = self.servers.get(fd)
            if server:
                self._handle_server_event(fd, server, event)
            else:
                self._handle_connection_event(fd, event)

    def _handle_server_event(self, fd, server, event):
        """Drop a failed listening socket, or accept one pending connection.

        While max_connects is reached, the pending connection is left queued.
        """
        if event & (POLLHUP | POLLERR):
            self.poll.unregister(server)
            server.close()
            del self.servers[fd]
            print("lost server socket")
        elif len(self.single_sockets) < self.max_connects:
            try:
                newsock, addr = server.accept()
                newsock.setblocking(0)
                nss = SingleSocket(self, newsock, self.handler)
                self.single_sockets[newsock.fileno()] = nss
                self.poll.register(newsock, POLLIN)
                self.handler.external_connection_made(nss)
            except socket.error:
                # Presumably a back-off so persistent accept() failures don't spin.
                self._sleep()

    def _handle_connection_event(self, fd, event):
        """Handle readiness or errors reported for one tracked connection."""
        s = self.single_sockets.get(fd)
        if not s:
            return
        # Any event marks the connection connected; HUP/ERR still closes it.
        s.connected = True
        if event & (POLLHUP | POLLERR):
            self._close_socket(s)
            return
        if event & POLLIN:
            try:
                s.last_hit = clock()
                data = s.socket.recv(self.readsize)
                if not data:
                    self._close_socket(s)    # peer closed the connection
                else:
                    s.handler.data_came_in(s, data)
            except socket.error as e:
                if self._error_code(e) != EWOULDBLOCK:
                    self._close_socket(s)
                    return
        # The socket may have been closed above or by the data handler.
        if (event & POLLOUT) and s.socket and not s.is_flushed():
            s.try_write()
            if s.is_flushed():
                s.handler.connection_flushed(s)

    def close_dead(self):
        """Close connections queued in dead_from_write, including ones queued meanwhile."""
        while self.dead_from_write:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket:
                    self._close_socket(s)

    def _close_socket(self, s):
        """Close ``s`` and notify its handler via connection_lost()."""
        s.close()
        s.handler.connection_lost(s)

    def do_poll(self, t):
        """Wait up to ``t`` seconds and return the list of (fileno, event) pairs.

        If poll() returns None, this is treated as a failure. In that case:

        * max_connects is lowered;
        * about 5% of connections (at least one), chosen at random, are closed;
        * an empty list is returned.
        """
        r = self.poll.poll(t * timemult)
        if r is not None:
            return r
        connects = len(self.single_sockets)
        to_close = int(connects * 0.05) + 1    # close 5% of sockets
        self.max_connects = connects - to_close
        closelist = list(self.single_sockets.values())
        shuffle(closelist)
        for sock in closelist[:to_close]:
            # An earlier connection_lost callback may already have closed it.
            if sock.socket is not None:
                self._close_socket(sock)
        return []

    def get_stats(self):
        """Return the bound interfaces, port and whether UPnP forwarding is active."""
        return { 'interfaces': self.interfaces,
                 'port': self.port,
                 'upnp': self.port_forwarded is not None }

    def shutdown(self):
        """Close all connections and listening sockets, then remove the UPnP mapping.

        Errors while closing are ignored so shutdown always runs to the end.
        Handlers are not notified of the closed connections.
        """
        # Iterate over copies: SingleSocket.close() removes itself from the dict.
        for ss in list(self.single_sockets.values()):
            try:
                ss.close()
            except Exception:
                pass
        for server in list(self.servers.values()):
            try:
                server.close()
            except Exception:
                pass
        if self.port_forwarded is not None:
            UPnP_close_port(self.port_forwarded)
375 |