From f395df091e8bd00ce9df26fec35d8e2a0b77a41d Mon Sep 17 00:00:00 2001
From: Jeremy Hylton
Date: Tue, 2 Jul 2002 20:42:50 +0000
Subject: [PATCH] Backport various bug fixes from trunk.

The 2.1 maintenance branch is now identical to the trunk through rev
1.54 of httplib.py.
---
 Lib/httplib.py           | 275 +++++++++++++++++++++++++++------------
 Lib/test/test_httplib.py |   4 +-
 Lib/urlparse.py          |  84 +++++++-----
 3 files changed, 246 insertions(+), 117 deletions(-)

diff --git a/Lib/httplib.py b/Lib/httplib.py
index db704b39f396..1540ba724905 100644
--- a/Lib/httplib.py
+++ b/Lib/httplib.py
@@ -66,8 +66,10 @@ Req-started-unread-response _CS_REQ_STARTED
 Req-sent-unread-response    _CS_REQ_SENT
 """

-import socket
+import errno
 import mimetools
+import socket
+from urlparse import urlsplit

 try:
     from cStringIO import StringIO
@@ -76,10 +78,10 @@ except ImportError:

 __all__ = ["HTTP", "HTTPResponse", "HTTPConnection", "HTTPSConnection",
            "HTTPException", "NotConnected", "UnknownProtocol",
-           "UnknownTransferEncoding", "IllegalKeywordArgument",
-           "UnimplementedFileMode", "IncompleteRead",
-           "ImproperConnectionState", "CannotSendRequest", "CannotSendHeader",
-           "ResponseNotReady", "BadStatusLine", "error"]
+           "UnknownTransferEncoding", "UnimplementedFileMode",
+           "IncompleteRead", "InvalidURL", "ImproperConnectionState",
+           "CannotSendRequest", "CannotSendHeader", "ResponseNotReady",
+           "BadStatusLine", "error"]

 HTTP_PORT = 80
 HTTPS_PORT = 443
@@ -109,11 +111,7 @@ class HTTPResponse:
         self.length = _UNKNOWN         # number of bytes left in response
         self.will_close = _UNKNOWN     # conn will close at end of response

-    def begin(self):
-        if self.msg is not None:
-            # we've already started reading the response
-            return
-
+    def _read_status(self):
         line = self.fp.readline()
         if self.debuglevel > 0:
             print "reply:", repr(line)
@@ -133,13 +131,33 @@ class HTTPResponse:

         # The status code is a three-digit number
         try:
-            self.status = status = int(status)
+            status = int(status)
             if status < 100 or status > 999:
                 raise BadStatusLine(line)
         except ValueError:
             raise BadStatusLine(line)
-        self.reason = reason.strip()
+        return version, status, reason

+    def _begin(self):
+        if self.msg is not None:
+            # we've already started reading the response
+            return
+
+        # read until we get a non-100 response
+        while 1:
+            version, status, reason = self._read_status()
+            if status != 100:
+                break
+            # skip the header from the 100 response
+            while 1:
+                skip = self.fp.readline().strip()
+                if not skip:
+                    break
+                if self.debuglevel > 0:
+                    print "header:", skip
+
+        self.status = status
+        self.reason = reason.strip()
         if version == 'HTTP/1.0':
             self.version = 10
         elif version.startswith('HTTP/1.'):
@@ -150,6 +168,7 @@ class HTTPResponse:
             raise UnknownProtocol(version)

         if self.version == 9:
+            self.chunked = 0
             self.msg = mimetools.Message(StringIO())
             return

@@ -231,6 +250,7 @@ class HTTPResponse:
             return ''

         if self.chunked:
+            assert self.chunked != _UNKNOWN
             chunk_left = self.chunk_left
             value = ''
             while 1:
@@ -345,7 +365,10 @@ class HTTPConnection:
         if port is None:
             i = host.find(':')
             if i >= 0:
-                port = int(host[i+1:])
+                try:
+                    port = int(host[i+1:])
+                except ValueError:
+                    raise InvalidURL("nonnumeric port: '%s'" % host[i+1:])
                 host = host[:i]
             else:
                 port = self.default_port
@@ -394,7 +417,7 @@ class HTTPConnection:
                 self.close()
             raise

-    def putrequest(self, method, url):
+    def putrequest(self, method, url, skip_host=0):
         """Send a request to the server.

         `method' specifies an HTTP request method, e.g. 'GET'.
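
Note: to make the _read_status()/_begin() split above concrete, here is a small
illustrative sketch. It is not part of the patch; the _FakeSock helper is invented
for this example (the real test suite uses a similar stub in Lib/test/test_httplib.py,
further down). It feeds a canned reply that starts with a "100 Continue" interim
response and checks that _begin() skips it.

    import httplib
    from StringIO import StringIO

    class _FakeSock:
        # minimal stand-in: HTTPResponse only calls makefile() on the socket
        def __init__(self, text):
            self._text = text
        def makefile(self, mode, bufsize=None):
            return StringIO(self._text)

    body = ("HTTP/1.1 100 Continue\r\n"
            "\r\n"
            "HTTP/1.1 200 OK\r\n"
            "Content-Length: 4\r\n"
            "\r\n"
            "Text")
    resp = httplib.HTTPResponse(_FakeSock(body))
    resp._begin()
    print resp.status, resp.reason   # 200 OK: the 100 Continue was skipped
    print resp.read()                # Text
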
@@ -445,18 +468,31 @@ class HTTPConnection:
         if self._http_vsn == 11:
             # Issue some standard headers for better HTTP/1.1 compliance

-            # this header is issued *only* for HTTP/1.1 connections. more
-            # specifically, this means it is only issued when the client uses
-            # the new HTTPConnection() class. backwards-compat clients will
-            # be using HTTP/1.0 and those clients may be issuing this header
-            # themselves. we should NOT issue it twice; some web servers (such
-            # as Apache) barf when they see two Host: headers
-
-            # if we need a non-standard port,include it in the header
-            if self.port == HTTP_PORT:
-                self.putheader('Host', self.host)
-            else:
-                self.putheader('Host', "%s:%s" % (self.host, self.port))
+            if not skip_host:
+                # this header is issued *only* for HTTP/1.1
+                # connections. more specifically, this means it is
+                # only issued when the client uses the new
+                # HTTPConnection() class. backwards-compat clients
+                # will be using HTTP/1.0 and those clients may be
+                # issuing this header themselves. we should NOT issue
+                # it twice; some web servers (such as Apache) barf
+                # when they see two Host: headers
+
+                # If we need a non-standard port,include it in the
+                # header. If the request is going through a proxy,
+                # but the host of the actual URL, not the host of the
+                # proxy.
+
+                netloc = ''
+                if url.startswith('http'):
+                    nil, netloc, nil, nil, nil = urlsplit(url)
+
+                if netloc:
+                    self.putheader('Host', netloc)
+                elif self.port == HTTP_PORT:
+                    self.putheader('Host', self.host)
+                else:
+                    self.putheader('Host', "%s:%s" % (self.host, self.port))

             # note: we are assuming that clients will not attempt to set these
             # headers since *this* library must deal with the
@@ -514,7 +550,14 @@ class HTTPConnection:
         self._send_request(method, url, body, headers)

     def _send_request(self, method, url, body, headers):
-        self.putrequest(method, url)
+        # If headers already contains a host header, then define the
+        # optional skip_host argument to putrequest(). The check is
+        # harder because field names are case insensitive.
+        if 'Host' in (headers
+            or [k for k in headers.iterkeys() if k.lower() == "host"]):
+            self.putrequest(method, url, skip_host=1)
+        else:
+            self.putrequest(method, url)

         if body:
             self.putheader('Content-Length', str(len(body)))
@@ -556,7 +599,8 @@ class HTTPConnection:
         else:
             response = self.response_class(self.sock)

-        response.begin()
+        response._begin()
+        assert response.will_close != _UNKNOWN
         self.__state = _CS_IDLE

         if response.will_close:
@@ -568,6 +612,83 @@ class HTTPConnection:
             return response

+class SSLFile:
+    """File-like object wrapping an SSL socket."""
+
+    BUFSIZE = 8192
+
+    def __init__(self, sock, ssl, bufsize=None):
+        self._sock = sock
+        self._ssl = ssl
+        self._buf = ''
+        self._bufsize = bufsize or self.__class__.BUFSIZE
+
+    def _read(self):
+        buf = ''
+        # put in a loop so that we retry on transient errors
+        while 1:
+            try:
+                buf = self._ssl.read(self._bufsize)
+            except socket.sslerror, err:
+                if (err[0] == socket.SSL_ERROR_WANT_READ
+                    or err[0] == socket.SSL_ERROR_WANT_WRITE):
+                    continue
+                if (err[0] == socket.SSL_ERROR_ZERO_RETURN
+                    or err[0] == socket.SSL_ERROR_EOF):
+                    break
+                raise
+            except socket.error, err:
+                if err[0] == errno.EINTR:
+                    continue
+                if err[0] == errno.EBADF:
+                    # XXX socket was closed?
+                    break
+                raise
+            else:
+                break
+        return buf
+
+    def read(self, size=None):
+        L = [self._buf]
+        avail = len(self._buf)
+        while size is None or avail < size:
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+            avail += len(s)
+        all = "".join(L)
+        if size is None:
+            self._buf = ''
+            return all
+        else:
+            self._buf = all[size:]
+            return all[:size]
+
+    def readline(self):
+        L = [self._buf]
+        self._buf = ''
+        while 1:
+            i = L[-1].find("\n")
+            if i >= 0:
+                break
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+        if i == -1:
+            # loop exited because there is no more data
+            return "".join(L)
+        else:
+            all = "".join(L)
+            # XXX could do enough bookkeeping not to do a 2nd search
+            i = all.find("\n") + 1
+            line = all[:i]
+            self._buf = all[i:]
+            return line
+
+    def close(self):
+        self._sock.close()

 class FakeSocket:
     def __init__(self, sock, ssl):
         self.__sock = sock
@@ -575,27 +696,9 @@ class FakeSocket:
         self.__ssl = ssl

     def makefile(self, mode, bufsize=None):
-        """Return a readable file-like object with data from socket.
-
-        This method offers only partial support for the makefile
-        interface of a real socket. It only supports modes 'r' and
-        'rb' and the bufsize argument is ignored.
-
-        The returned object contains *all* of the file data
-        """
         if mode != 'r' and mode != 'rb':
             raise UnimplementedFileMode()
-
-        msgbuf = []
-        while 1:
-            try:
-                buf = self.__ssl.read()
-            except socket.sslerror, msg:
-                break
-            if buf == '':
-                break
-            msgbuf.append(buf)
-        return StringIO("".join(msgbuf))
+        return SSLFile(self.__sock, self.__ssl, bufsize)

     def send(self, stuff, flags = 0):
         return self.__ssl.write(stuff)
@@ -615,21 +718,10 @@ class HTTPSConnection(HTTPConnection):

     default_port = HTTPS_PORT

-    def __init__(self, host, port=None, **x509):
-        keys = x509.keys()
-        try:
-            keys.remove('key_file')
-        except ValueError:
-            pass
-        try:
-            keys.remove('cert_file')
-        except ValueError:
-            pass
-        if keys:
-            raise IllegalKeywordArgument()
+    def __init__(self, host, port=None, key_file=None, cert_file=None):
         HTTPConnection.__init__(self, host, port)
-        self.key_file = x509.get('key_file')
-        self.cert_file = x509.get('cert_file')
+        self.key_file = key_file
+        self.cert_file = cert_file

     def connect(self):
         "Connect to a host on a given (SSL) port."
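
Note: a hedged sketch of the constructor change just above. Client certificates are
now passed to HTTPSConnection as explicit key_file/cert_file keyword arguments
instead of the old **x509 catch-all (IllegalKeywordArgument disappears with it), and
the response body comes back through the new lazy SSLFile wrapper instead of being
slurped into a StringIO up front. The host name and certificate paths below are
placeholders, and actually running this needs a Python built with socket.ssl.

    import httplib

    conn = httplib.HTTPSConnection('sourceforge.net', 443,
                                   key_file='client.key',   # placeholder path
                                   cert_file='client.crt')  # placeholder path
    conn.request('GET', '/')
    resp = conn.getresponse()
    print resp.status, resp.reason
    print "read", len(resp.read())
    conn.close()
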
@@ -653,7 +745,7 @@ class HTTP:

     _connection_class = HTTPConnection

-    def __init__(self, host='', port=None, **x509):
+    def __init__(self, host='', port=None):
         "Provide a default host, since the superclass requires one."

         # some joker passed 0 explicitly, meaning default port
@@ -663,18 +755,19 @@ class HTTP:
         # Note that we may pass an empty string as the host; this will throw
         # an error when we attempt to connect. Presumably, the client code
         # will call connect before then, with a proper host.
-        self._conn = self._connection_class(host, port)
+        self._setup(self._connection_class(host, port))
+
+    def _setup(self, conn):
+        self._conn = conn
+
         # set up delegation to flesh out interface
-        self.send = self._conn.send
-        self.putrequest = self._conn.putrequest
-        self.endheaders = self._conn.endheaders
-        self._conn._http_vsn = self._http_vsn
-        self._conn._http_vsn_str = self._http_vsn_str
+        self.send = conn.send
+        self.putrequest = conn.putrequest
+        self.endheaders = conn.endheaders
+        self.set_debuglevel = conn.set_debuglevel

-        # we never actually use these for anything, but we keep them here for
-        # compatibility with post-1.5.2 CVS.
-        self.key_file = x509.get('key_file')
-        self.cert_file = x509.get('cert_file')
+        conn._http_vsn = self._http_vsn
+        conn._http_vsn_str = self._http_vsn_str

         self.file = None

@@ -685,9 +778,6 @@ class HTTP:
         self._conn._set_hostport(host, port)
         self._conn.connect()

-    def set_debuglevel(self, debuglevel):
-        self._conn.set_debuglevel(debuglevel)
-
     def getfile(self):
         "Provide a getfile, since the superclass' does not use this concept."
         return self.file
@@ -745,6 +835,19 @@ if hasattr(socket, 'ssl'):

         _connection_class = HTTPSConnection

+        def __init__(self, host='', port=None, **x509):
+            # provide a default host, pass the X509 cert info
+
+            # urf. compensate for bad input.
+            if port == 0:
+                port = None
+            self._setup(self._connection_class(host, port, **x509))
+
+            # we never actually use these for anything, but we keep them
+            # here for compatibility with post-1.5.2 CVS.
+            self.key_file = x509.get('key_file')
+            self.cert_file = x509.get('cert_file')
+

 class HTTPException(Exception):
     pass
@@ -752,6 +855,9 @@ class HTTPException(Exception):
 class NotConnected(HTTPException):
     pass

+class InvalidURL(HTTPException):
+    pass
+
 class UnknownProtocol(HTTPException):
     def __init__(self, version):
         self.version = version
@@ -759,9 +865,6 @@ class UnknownProtocol(HTTPException):
 class UnknownTransferEncoding(HTTPException):
     pass

-class IllegalKeywordArgument(HTTPException):
-    pass
-
 class UnimplementedFileMode(HTTPException):
     pass

@@ -822,7 +925,18 @@ def test():
     if headers:
         for header in headers.headers: print header.strip()
     print
-    print h.getfile().read()
+    print "read", len(h.getfile().read())
+
+    # minimal test that code to extract host from url works
+    class HTTP11(HTTP):
+        _http_vsn = 11
+        _http_vsn_str = 'HTTP/1.1'
+
+    h = HTTP11('www.python.org')
+    h.putrequest('GET', 'http://www.python.org/~jeremy/')
+    h.endheaders()
+    h.getreply()
+    h.close()

     if hasattr(socket, 'ssl'):
         host = 'sourceforge.net'
@@ -832,13 +946,14 @@ def test():
         hs.putrequest('GET', selector)
         hs.endheaders()
         status, reason, headers = hs.getreply()
+        # XXX why does this give a 302 response?
         print 'status =', status
         print 'reason =', reason
         print
         if headers:
             for header in headers.headers: print header.strip()
         print
-        print hs.getfile().read()
+        print "read", len(hs.getfile().read())

 if __name__ == '__main__':
     test()
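
Note: before moving on to the tests, the Host-header selection that putrequest() now
performs is easy to restate on its own. The helper below is not part of httplib; it
is an illustrative rewrite of the branch added above, and it relies on the urlsplit()
that this same patch adds to urlparse.py. The proxy host and port are made-up values.

    from urlparse import urlsplit

    HTTP_PORT = 80

    def host_header(url, conn_host, conn_port):
        # mirrors the skip_host branch added to HTTPConnection.putrequest()
        netloc = ''
        if url.startswith('http'):
            nil, netloc, nil, nil, nil = urlsplit(url)
        if netloc:
            return netloc
        elif conn_port == HTTP_PORT:
            return conn_host
        else:
            return "%s:%s" % (conn_host, conn_port)

    print host_header('http://www.python.org/~jeremy/', 'proxy.example.com', 3128)
    # www.python.org: an absolute URL names the real host, not the proxy
    print host_header('/~jeremy/', 'www.python.org', 80)
    # www.python.org: the default port is left out
    print host_header('/~jeremy/', 'www.python.org', 8080)
    # www.python.org:8080: a non-standard port is included
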
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
index aef65a68112e..bc49d62468ea 100644
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -16,7 +16,7 @@ class FakeSocket:
 body = "HTTP/1.1 200 Ok\r\n\r\nText"
 sock = FakeSocket(body)
 resp = httplib.HTTPResponse(sock,1)
-resp.begin()
+resp._begin()
 print resp.read()
 resp.close()

@@ -24,7 +24,7 @@ body = "HTTP/1.1 400.100 Not Ok\r\n\r\nText"
 sock = FakeSocket(body)
 resp = httplib.HTTPResponse(sock,1)
 try:
-    resp.begin()
+    resp._begin()
 except httplib.BadStatusLine:
     print "BadStatusLine raised as expected"
 else:
diff --git a/Lib/urlparse.py b/Lib/urlparse.py
index 1df83d68d31b..ee99645d59b7 100644
--- a/Lib/urlparse.py
+++ b/Lib/urlparse.py
@@ -43,19 +43,42 @@ def clear_cache():
     _parse_cache = {}


-def urlparse(url, scheme = '', allow_fragments = 1):
+def urlparse(url, scheme='', allow_fragments=1):
     """Parse a URL into 6 components:
     <scheme>://<netloc>/<path>;<params>?<query>#<fragment>
     Return a 6-tuple: (scheme, netloc, path, params, query, fragment).
     Note that we don't break the components up in smaller bits (e.g.
     netloc is a single string) and we don't expand % escapes."""
+    tuple = urlsplit(url, scheme, allow_fragments)
+    scheme, netloc, url, query, fragment = tuple
+    if scheme in uses_params and ';' in url:
+        url, params = _splitparams(url)
+    else:
+        params = ''
+    return scheme, netloc, url, params, query, fragment
+
+def _splitparams(url):
+    if '/' in url:
+        i = url.find(';', url.rfind('/'))
+        if i < 0:
+            return url, ''
+    else:
+        i = url.find(';')
+    return url[:i], url[i+1:]
+
+def urlsplit(url, scheme='', allow_fragments=1):
+    """Parse a URL into 5 components:
+    <scheme>://<netloc>/<path>?<query>#<fragment>
+    Return a 5-tuple: (scheme, netloc, path, query, fragment).
+    Note that we don't break the components up in smaller bits
+    (e.g. netloc is a single string) and we don't expand % escapes."""
     key = url, scheme, allow_fragments
     cached = _parse_cache.get(key, None)
     if cached:
         return cached
     if len(_parse_cache) >= MAX_CACHE_SIZE: # avoid runaway growth
         clear_cache()
-    netloc = path = params = query = fragment = ''
+    netloc = query = fragment = ''
     i = url.find(':')
     if i > 0:
         if url[:i] == 'http': # optimize the common case
@@ -64,23 +87,16 @@ def urlparse(url, scheme = '', allow_fragments = 1):
             if url[:2] == '//':
                 i = url.find('/', 2)
                 if i < 0:
-                    i = len(url)
+                    i = url.find('#')
+                    if i < 0:
+                        i = len(url)
                 netloc = url[2:i]
                 url = url[i:]
-            if allow_fragments:
-                i = url.rfind('#')
-                if i >= 0:
-                    fragment = url[i+1:]
-                    url = url[:i]
-            i = url.find('?')
-            if i >= 0:
-                query = url[i+1:]
-                url = url[:i]
-            i = url.find(';')
-            if i >= 0:
-                params = url[i+1:]
-                url = url[:i]
-            tuple = scheme, netloc, url, params, query, fragment
+            if allow_fragments and '#' in url:
+                url, fragment = url.split('#', 1)
+            if '?' in url:
+                url, query = url.split('?', 1)
+            tuple = scheme, netloc, url, query, fragment
             _parse_cache[key] = tuple
             return tuple
     for c in url[:i]:
@@ -94,19 +110,11 @@ def urlparse(url, scheme = '', allow_fragments = 1):
         if i < 0:
             i = len(url)
         netloc, url = url[2:i], url[i:]
-    if allow_fragments and scheme in uses_fragment:
-        i = url.rfind('#')
-        if i >= 0:
-            url, fragment = url[:i], url[i+1:]
-    if scheme in uses_query:
-        i = url.find('?')
-        if i >= 0:
-            url, query = url[:i], url[i+1:]
-    if scheme in uses_params:
-        i = url.find(';')
-        if i >= 0:
-            url, params = url[:i], url[i+1:]
-    tuple = scheme, netloc, url, params, query, fragment
+    if allow_fragments and scheme in uses_fragment and '#' in url:
+        url, fragment = url.split('#', 1)
+    if scheme in uses_query and '?' in url:
+        url, query = url.split('?', 1)
+    tuple = scheme, netloc, url, query, fragment
     _parse_cache[key] = tuple
     return tuple

@@ -115,13 +123,16 @@ def urlunparse((scheme, netloc, url, params, query, fragment)):
     slightly different, but equivalent URL, if the URL that was parsed
     originally had redundant delimiters, e.g. a ? with an empty query
     (the draft states that these are equivalent)."""
+    if params:
+        url = "%s;%s" % (url, params)
+    return urlunsplit((scheme, netloc, url, query, fragment))
+
+def urlunsplit((scheme, netloc, url, query, fragment)):
     if netloc or (scheme in uses_netloc and url[:2] == '//'):
         if url and url[:1] != '/': url = '/' + url
         url = '//' + (netloc or '') + url
     if scheme:
         url = scheme + ':' + url
-    if params:
-        url = url + ';' + params
     if query:
         url = url + '?' + query
     if fragment:
@@ -187,9 +198,12 @@ def urldefrag(url):
     the URL contained no fragments, the second element is the empty
     string.
""" - s, n, p, a, q, frag = urlparse(url) - defrag = urlunparse((s, n, p, a, q, '')) - return defrag, frag + if '#' in url: + s, n, p, a, q, frag = urlparse(url) + defrag = urlunparse((s, n, p, a, q, '')) + return defrag, frag + else: + return url, '' test_input = """ -- 2.47.3