From: Jeremy Hylton Date: Tue, 2 Jul 2002 17:19:47 +0000 (+0000) Subject: Backport various bug fixes from trunk. X-Git-Tag: v2.2.2b1~286 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5b4f1c12011d21b55dbe0bd07f30c49d0a197473;p=thirdparty%2FPython%2Fcpython.git Backport various bug fixes from trunk. The 2.2 maintenance branch is now identical to the trunk through rev 1.53. --- diff --git a/Lib/httplib.py b/Lib/httplib.py index cbe6e8ffee58..5a039e7a1c07 100644 --- a/Lib/httplib.py +++ b/Lib/httplib.py @@ -78,10 +78,10 @@ except ImportError: __all__ = ["HTTP", "HTTPResponse", "HTTPConnection", "HTTPSConnection", "HTTPException", "NotConnected", "UnknownProtocol", - "UnknownTransferEncoding", "IllegalKeywordArgument", - "UnimplementedFileMode", "IncompleteRead", - "ImproperConnectionState", "CannotSendRequest", "CannotSendHeader", - "ResponseNotReady", "BadStatusLine", "error"] + "UnknownTransferEncoding", "UnimplementedFileMode", + "IncompleteRead", "InvalidURL", "ImproperConnectionState", + "CannotSendRequest", "CannotSendHeader", "ResponseNotReady", + "BadStatusLine", "error"] HTTP_PORT = 80 HTTPS_PORT = 443 @@ -111,11 +111,7 @@ class HTTPResponse: self.length = _UNKNOWN # number of bytes left in response self.will_close = _UNKNOWN # conn will close at end of response - def begin(self): - if self.msg is not None: - # we've already started reading the response - return - + def _read_status(self): line = self.fp.readline() if self.debuglevel > 0: print "reply:", repr(line) @@ -135,13 +131,33 @@ class HTTPResponse: # The status code is a three-digit number try: - self.status = status = int(status) + status = int(status) if status < 100 or status > 999: raise BadStatusLine(line) except ValueError: raise BadStatusLine(line) - self.reason = reason.strip() + return version, status, reason + def _begin(self): + if self.msg is not None: + # we've already started reading the response + return + + # read until we get a non-100 response + while 1: + version, status, 
reason = self._read_status() + if status != 100: + break + # skip the header from the 100 response + while 1: + skip = self.fp.readline().strip() + if not skip: + break + if self.debuglevel > 0: + print "header:", skip + + self.status = status + self.reason = reason.strip() if version == 'HTTP/1.0': self.version = 10 elif version.startswith('HTTP/1.'): @@ -152,6 +168,7 @@ class HTTPResponse: raise UnknownProtocol(version) if self.version == 9: + self.chunked = 0 self.msg = mimetools.Message(StringIO()) return @@ -233,6 +250,7 @@ class HTTPResponse: return '' if self.chunked: + assert self.chunked != _UNKNOWN chunk_left = self.chunk_left value = '' while 1: @@ -347,7 +365,10 @@ class HTTPConnection: if port is None: i = host.find(':') if i >= 0: - port = int(host[i+1:]) + try: + port = int(host[i+1:]) + except ValueError: + raise InvalidURL, "nonnumeric port: '%s'"%host[i+1:] host = host[:i] else: port = self.default_port @@ -360,7 +381,8 @@ class HTTPConnection: def connect(self): """Connect to the host and port specified in __init__.""" msg = "getaddrinfo returns an empty list" - for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM): + for res in socket.getaddrinfo(self.host, self.port, 0, + socket.SOCK_STREAM): af, socktype, proto, canonname, sa = res try: self.sock = socket.socket(af, socktype, proto) @@ -546,7 +568,7 @@ class HTTPConnection: # If headers already contains a host header, then define the # optional skip_host argument to putrequest(). The check is # harder because field names are case insensitive. 
- if (headers.has_key('Host') + if 'Host' in (headers or [k for k in headers.iterkeys() if k.lower() == "host"]): self.putrequest(method, url, skip_host=1) else: @@ -592,7 +614,8 @@ class HTTPConnection: else: response = self.response_class(self.sock) - response.begin() + response._begin() + assert response.will_close != _UNKNOWN self.__state = _CS_IDLE if response.will_close: @@ -604,45 +627,93 @@ class HTTPConnection: return response +class SSLFile: + """File-like object wrapping an SSL socket.""" -class FakeSocket: - def __init__(self, sock, ssl): - self.__sock = sock - self.__ssl = ssl - - def makefile(self, mode, bufsize=None): - """Return a readable file-like object with data from socket. - - This method offers only partial support for the makefile - interface of a real socket. It only supports modes 'r' and - 'rb' and the bufsize argument is ignored. + BUFSIZE = 8192 + + def __init__(self, sock, ssl, bufsize=None): + self._sock = sock + self._ssl = ssl + self._buf = '' + self._bufsize = bufsize or self.__class__.BUFSIZE - The returned object contains *all* of the file data - """ - if mode != 'r' and mode != 'rb': - raise UnimplementedFileMode() - - msgbuf = [] + def _read(self): + buf = '' + # put in a loop so that we retry on transient errors while 1: try: - buf = self.__ssl.read() + buf = self._ssl.read(self._bufsize) except socket.sslerror, err: if (err[0] == socket.SSL_ERROR_WANT_READ - or err[0] == socket.SSL_ERROR_WANT_WRITE - or 0): + or err[0] == socket.SSL_ERROR_WANT_WRITE): continue - if (err[0] == socket.SSL_ERROR_ZERO_RETURN - or err[0] == socket.SSL_ERROR_EOF): + if (err[0] == socket.SSL_ERROR_ZERO_RETURN + or err[0] == socket.SSL_ERROR_EOF): break raise except socket.error, err: if err[0] == errno.EINTR: continue + if err[0] == errno.EBADF: + # XXX socket was closed? 
+ break raise - if buf == '': + else: break - msgbuf.append(buf) - return StringIO("".join(msgbuf)) + return buf + + def read(self, size=None): + L = [self._buf] + avail = len(self._buf) + while size is None or avail < size: + s = self._read() + if s == '': + break + L.append(s) + avail += len(s) + all = "".join(L) + if size is None: + self._buf = '' + return all + else: + self._buf = all[size:] + return all[:size] + + def readline(self): + L = [self._buf] + self._buf = '' + while 1: + i = L[-1].find("\n") + if i >= 0: + break + s = self._read() + if s == '': + break + L.append(s) + if i == -1: + # loop exited because there is no more data + return "".join(L) + else: + all = "".join(L) + # XXX could do enough bookkeeping not to do a 2nd search + i = all.find("\n") + 1 + line = all[:i] + self._buf = all[i:] + return line + + def close(self): + self._sock.close() + +class FakeSocket: + def __init__(self, sock, ssl): + self.__sock = sock + self.__ssl = ssl + + def makefile(self, mode, bufsize=None): + if mode != 'r' and mode != 'rb': + raise UnimplementedFileMode() + return SSLFile(self.__sock, self.__ssl, bufsize) def send(self, stuff, flags = 0): return self.__ssl.write(stuff) @@ -662,21 +733,10 @@ class HTTPSConnection(HTTPConnection): default_port = HTTPS_PORT - def __init__(self, host, port=None, **x509): - keys = x509.keys() - try: - keys.remove('key_file') - except ValueError: - pass - try: - keys.remove('cert_file') - except ValueError: - pass - if keys: - raise IllegalKeywordArgument() + def __init__(self, host, port=None, key_file=None, cert_file=None): HTTPConnection.__init__(self, host, port) - self.key_file = x509.get('key_file') - self.cert_file = x509.get('cert_file') + self.key_file = key_file + self.cert_file = cert_file def connect(self): "Connect to a host on a given (SSL) port." 
@@ -810,6 +870,9 @@ class HTTPException(Exception): class NotConnected(HTTPException): pass +class InvalidURL(HTTPException): + pass + class UnknownProtocol(HTTPException): def __init__(self, version): self.version = version @@ -817,9 +880,6 @@ class UnknownProtocol(HTTPException): class UnknownTransferEncoding(HTTPException): pass -class IllegalKeywordArgument(HTTPException): - pass - class UnimplementedFileMode(HTTPException): pass @@ -880,7 +940,7 @@ def test(): if headers: for header in headers.headers: print header.strip() print - print h.getfile().read() + print "read", len(h.getfile().read()) # minimal test that code to extract host from url works class HTTP11(HTTP): @@ -901,13 +961,14 @@ def test(): hs.putrequest('GET', selector) hs.endheaders() status, reason, headers = hs.getreply() + # XXX why does this give a 302 response? print 'status =', status print 'reason =', reason print if headers: for header in headers.headers: print header.strip() print - print hs.getfile().read() + print "read", len(hs.getfile().read()) if __name__ == '__main__':