]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Backport various bug fixes from trunk.
authorJeremy Hylton <jeremy@alum.mit.edu>
Tue, 2 Jul 2002 20:42:50 +0000 (20:42 +0000)
committerJeremy Hylton <jeremy@alum.mit.edu>
Tue, 2 Jul 2002 20:42:50 +0000 (20:42 +0000)
The 2.1 maintenance branch is now identical to the trunk through rev
1.54 of httplib.py.

Lib/httplib.py
Lib/test/test_httplib.py
Lib/urlparse.py

index db704b39f396af68f744d2cf2ab683e86f4ac8d1..1540ba724905f1a1cfc1db1089b9166f1ce029bb 100644 (file)
@@ -66,8 +66,10 @@ Req-started-unread-response    _CS_REQ_STARTED    <response_class>
 Req-sent-unread-response       _CS_REQ_SENT       <response_class>
 """
 
-import socket
+import errno
 import mimetools
+import socket
+from urlparse import urlsplit
 
 try:
     from cStringIO import StringIO
@@ -76,10 +78,10 @@ except ImportError:
 
 __all__ = ["HTTP", "HTTPResponse", "HTTPConnection", "HTTPSConnection",
            "HTTPException", "NotConnected", "UnknownProtocol",
-           "UnknownTransferEncoding", "IllegalKeywordArgument",
-           "UnimplementedFileMode", "IncompleteRead",
-           "ImproperConnectionState", "CannotSendRequest", "CannotSendHeader",
-           "ResponseNotReady", "BadStatusLine", "error"]
+           "UnknownTransferEncoding", "UnimplementedFileMode",
+           "IncompleteRead", "InvalidURL", "ImproperConnectionState",
+           "CannotSendRequest", "CannotSendHeader", "ResponseNotReady",
+           "BadStatusLine", "error"]
 
 HTTP_PORT = 80
 HTTPS_PORT = 443
@@ -109,11 +111,7 @@ class HTTPResponse:
         self.length = _UNKNOWN          # number of bytes left in response
         self.will_close = _UNKNOWN      # conn will close at end of response
 
-    def begin(self):
-        if self.msg is not None:
-            # we've already started reading the response
-            return
-
+    def _read_status(self):
         line = self.fp.readline()
         if self.debuglevel > 0:
             print "reply:", repr(line)
@@ -133,13 +131,33 @@ class HTTPResponse:
 
         # The status code is a three-digit number
         try:
-            self.status = status = int(status)
+            status = int(status)
             if status < 100 or status > 999:
                 raise BadStatusLine(line)
         except ValueError:
             raise BadStatusLine(line)
-        self.reason = reason.strip()
+        return version, status, reason
 
+    def _begin(self):
+        if self.msg is not None:
+            # we've already started reading the response
+            return
+
+        # read until we get a non-100 response
+        while 1:
+            version, status, reason = self._read_status()
+            if status != 100:
+                break
+            # skip the header from the 100 response
+            while 1:
+                skip = self.fp.readline().strip()
+                if not skip:
+                    break
+                if self.debuglevel > 0:
+                    print "header:", skip
+            
+        self.status = status
+        self.reason = reason.strip()
         if version == 'HTTP/1.0':
             self.version = 10
         elif version.startswith('HTTP/1.'):
@@ -150,6 +168,7 @@ class HTTPResponse:
             raise UnknownProtocol(version)
 
         if self.version == 9:
+            self.chunked = 0
             self.msg = mimetools.Message(StringIO())
             return
 
@@ -231,6 +250,7 @@ class HTTPResponse:
             return ''
 
         if self.chunked:
+            assert self.chunked != _UNKNOWN
             chunk_left = self.chunk_left
             value = ''
             while 1:
@@ -345,7 +365,10 @@ class HTTPConnection:
         if port is None:
             i = host.find(':')
             if i >= 0:
-                port = int(host[i+1:])
+                try:
+                    port = int(host[i+1:])
+                except ValueError:
+                    raise InvalidURL("nonnumeric port: '%s'" % host[i+1:])
                 host = host[:i]
             else:
                 port = self.default_port
@@ -394,7 +417,7 @@ class HTTPConnection:
                 self.close()
             raise
 
-    def putrequest(self, method, url):
+    def putrequest(self, method, url, skip_host=0):
         """Send a request to the server.
 
         `method' specifies an HTTP request method, e.g. 'GET'.
@@ -445,18 +468,31 @@ class HTTPConnection:
         if self._http_vsn == 11:
             # Issue some standard headers for better HTTP/1.1 compliance
 
-            # this header is issued *only* for HTTP/1.1 connections. more
-            # specifically, this means it is only issued when the client uses
-            # the new HTTPConnection() class. backwards-compat clients will
-            # be using HTTP/1.0 and those clients may be issuing this header
-            # themselves. we should NOT issue it twice; some web servers (such
-            # as Apache) barf when they see two Host: headers
-
-            # if we need a non-standard port,include it in the header
-            if self.port == HTTP_PORT:
-                self.putheader('Host', self.host)
-            else:
-                self.putheader('Host', "%s:%s" % (self.host, self.port))
+            if not skip_host:
+                # this header is issued *only* for HTTP/1.1
+                # connections. more specifically, this means it is
+                # only issued when the client uses the new
+                # HTTPConnection() class. backwards-compat clients
+                # will be using HTTP/1.0 and those clients may be
+                # issuing this header themselves. we should NOT issue
+                # it twice; some web servers (such as Apache) barf
+                # when they see two Host: headers
+
+                # If we need a non-standard port,include it in the
+                # header.  If the request is going through a proxy,
+                # but the host of the actual URL, not the host of the
+                # proxy.
+
+                netloc = ''
+                if url.startswith('http'):
+                    nil, netloc, nil, nil, nil = urlsplit(url)
+
+                if netloc:
+                    self.putheader('Host', netloc)
+                elif self.port == HTTP_PORT:
+                    self.putheader('Host', self.host)
+                else:
+                    self.putheader('Host', "%s:%s" % (self.host, self.port))
 
             # note: we are assuming that clients will not attempt to set these
             #       headers since *this* library must deal with the
@@ -514,7 +550,14 @@ class HTTPConnection:
             self._send_request(method, url, body, headers)
 
     def _send_request(self, method, url, body, headers):
-        self.putrequest(method, url)
+        # If headers already contains a host header, then define the
+        # optional skip_host argument to putrequest().  The check is
+        # harder because field names are case insensitive.
+        if 'Host' in (headers
+            or [k for k in headers.iterkeys() if k.lower() == "host"]):
+            self.putrequest(method, url, skip_host=1)
+        else:
+            self.putrequest(method, url)
 
         if body:
             self.putheader('Content-Length', str(len(body)))
@@ -556,7 +599,8 @@ class HTTPConnection:
         else:
             response = self.response_class(self.sock)
 
-        response.begin()
+        response._begin()
+        assert response.will_close != _UNKNOWN
         self.__state = _CS_IDLE
 
         if response.will_close:
@@ -568,6 +612,83 @@ class HTTPConnection:
 
         return response
 
+class SSLFile:
+    """File-like object wrapping an SSL socket."""
+
+    BUFSIZE = 8192
+    
+    def __init__(self, sock, ssl, bufsize=None):
+        self._sock = sock
+        self._ssl = ssl
+        self._buf = ''
+        self._bufsize = bufsize or self.__class__.BUFSIZE
+
+    def _read(self):
+        buf = ''
+        # put in a loop so that we retry on transient errors
+        while 1:
+            try:
+                buf = self._ssl.read(self._bufsize)
+            except socket.sslerror, err:
+                if (err[0] == socket.SSL_ERROR_WANT_READ
+                    or err[0] == socket.SSL_ERROR_WANT_WRITE):
+                    continue
+                if (err[0] == socket.SSL_ERROR_ZERO_RETURN
+                    or err[0] == socket.SSL_ERROR_EOF):
+                    break
+                raise
+            except socket.error, err:
+                if err[0] == errno.EINTR:
+                    continue
+                if err[0] == errno.EBADF:
+                    # XXX socket was closed?
+                    break
+                raise
+            else:
+                break
+        return buf
+
+    def read(self, size=None):
+        L = [self._buf]
+        avail = len(self._buf)
+        while size is None or avail < size:
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+            avail += len(s)
+        all = "".join(L)
+        if size is None:
+            self._buf = ''
+            return all
+        else:
+            self._buf = all[size:]
+            return all[:size]
+
+    def readline(self):
+        L = [self._buf]
+        self._buf = ''
+        while 1:
+            i = L[-1].find("\n")
+            if i >= 0:
+                break
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+        if i == -1:
+            # loop exited because there is no more data
+            return "".join(L)
+        else:
+            all = "".join(L)
+            # XXX could do enough bookkeeping not to do a 2nd search
+            i = all.find("\n") + 1
+            line = all[:i]
+            self._buf = all[i:]
+            return line
+
+    def close(self):
+        self._sock.close()
 
 class FakeSocket:
     def __init__(self, sock, ssl):
@@ -575,27 +696,9 @@ class FakeSocket:
         self.__ssl = ssl
 
     def makefile(self, mode, bufsize=None):
-        """Return a readable file-like object with data from socket.
-
-        This method offers only partial support for the makefile
-        interface of a real socket.  It only supports modes 'r' and
-        'rb' and the bufsize argument is ignored.
-
-        The returned object contains *all* of the file data
-        """
         if mode != 'r' and mode != 'rb':
             raise UnimplementedFileMode()
-
-        msgbuf = []
-        while 1:
-            try:
-                buf = self.__ssl.read()
-            except socket.sslerror, msg:
-                break
-            if buf == '':
-                break
-            msgbuf.append(buf)
-        return StringIO("".join(msgbuf))
+        return SSLFile(self.__sock, self.__ssl, bufsize)
 
     def send(self, stuff, flags = 0):
         return self.__ssl.write(stuff)
@@ -615,21 +718,10 @@ class HTTPSConnection(HTTPConnection):
 
     default_port = HTTPS_PORT
 
-    def __init__(self, host, port=None, **x509):
-        keys = x509.keys()
-        try:
-            keys.remove('key_file')
-        except ValueError:
-            pass
-        try:
-            keys.remove('cert_file')
-        except ValueError:
-            pass
-        if keys:
-            raise IllegalKeywordArgument()
+    def __init__(self, host, port=None, key_file=None, cert_file=None):
         HTTPConnection.__init__(self, host, port)
-        self.key_file = x509.get('key_file')
-        self.cert_file = x509.get('cert_file')
+        self.key_file = key_file
+        self.cert_file = cert_file
 
     def connect(self):
         "Connect to a host on a given (SSL) port."
@@ -653,7 +745,7 @@ class HTTP:
 
     _connection_class = HTTPConnection
 
-    def __init__(self, host='', port=None, **x509):
+    def __init__(self, host='', port=None):
         "Provide a default host, since the superclass requires one."
 
         # some joker passed 0 explicitly, meaning default port
@@ -663,18 +755,19 @@ class HTTP:
         # Note that we may pass an empty string as the host; this will throw
         # an error when we attempt to connect. Presumably, the client code
         # will call connect before then, with a proper host.
-        self._conn = self._connection_class(host, port)
+        self._setup(self._connection_class(host, port))
+
+    def _setup(self, conn):
+        self._conn = conn
+
         # set up delegation to flesh out interface
-        self.send = self._conn.send
-        self.putrequest = self._conn.putrequest
-        self.endheaders = self._conn.endheaders
-        self._conn._http_vsn = self._http_vsn
-        self._conn._http_vsn_str = self._http_vsn_str
+        self.send = conn.send
+        self.putrequest = conn.putrequest
+        self.endheaders = conn.endheaders
+        self.set_debuglevel = conn.set_debuglevel
 
-        # we never actually use these for anything, but we keep them here for
-        # compatibility with post-1.5.2 CVS.
-        self.key_file = x509.get('key_file')
-        self.cert_file = x509.get('cert_file')
+        conn._http_vsn = self._http_vsn
+        conn._http_vsn_str = self._http_vsn_str
 
         self.file = None
 
@@ -685,9 +778,6 @@ class HTTP:
             self._conn._set_hostport(host, port)
         self._conn.connect()
 
-    def set_debuglevel(self, debuglevel):
-        self._conn.set_debuglevel(debuglevel)
-
     def getfile(self):
         "Provide a getfile, since the superclass' does not use this concept."
         return self.file
@@ -745,6 +835,19 @@ if hasattr(socket, 'ssl'):
 
         _connection_class = HTTPSConnection
 
+        def __init__(self, host='', port=None, **x509):
+            # provide a default host, pass the X509 cert info
+
+            # urf. compensate for bad input.
+            if port == 0:
+                port = None
+            self._setup(self._connection_class(host, port, **x509))
+
+            # we never actually use these for anything, but we keep them
+            # here for compatibility with post-1.5.2 CVS.
+            self.key_file = x509.get('key_file')
+            self.cert_file = x509.get('cert_file')
+
 
 class HTTPException(Exception):
     pass
@@ -752,6 +855,9 @@ class HTTPException(Exception):
 class NotConnected(HTTPException):
     pass
 
+class InvalidURL(HTTPException):
+    pass
+
 class UnknownProtocol(HTTPException):
     def __init__(self, version):
         self.version = version
@@ -759,9 +865,6 @@ class UnknownProtocol(HTTPException):
 class UnknownTransferEncoding(HTTPException):
     pass
 
-class IllegalKeywordArgument(HTTPException):
-    pass
-
 class UnimplementedFileMode(HTTPException):
     pass
 
@@ -822,7 +925,18 @@ def test():
     if headers:
         for header in headers.headers: print header.strip()
     print
-    print h.getfile().read()
+    print "read", len(h.getfile().read())
+
+    # minimal test that code to extract host from url works
+    class HTTP11(HTTP):
+        _http_vsn = 11
+        _http_vsn_str = 'HTTP/1.1'
+
+    h = HTTP11('www.python.org')
+    h.putrequest('GET', 'http://www.python.org/~jeremy/')
+    h.endheaders()
+    h.getreply()
+    h.close()
 
     if hasattr(socket, 'ssl'):
         host = 'sourceforge.net'
@@ -832,13 +946,14 @@ def test():
         hs.putrequest('GET', selector)
         hs.endheaders()
         status, reason, headers = hs.getreply()
+        # XXX why does this give a 302 response?
         print 'status =', status
         print 'reason =', reason
         print
         if headers:
             for header in headers.headers: print header.strip()
         print
-        print hs.getfile().read()
+        print "read", len(hs.getfile().read())
 
 
 if __name__ == '__main__':
index aef65a68112e399a96a38ffedd54bb43e9e1d728..bc49d62468eab220fef66b837862c8512b3c5cf5 100644 (file)
@@ -16,7 +16,7 @@ class FakeSocket:
 body = "HTTP/1.1 200 Ok\r\n\r\nText"
 sock = FakeSocket(body)
 resp = httplib.HTTPResponse(sock,1)
-resp.begin()
+resp._begin()
 print resp.read()
 resp.close()
 
@@ -24,7 +24,7 @@ body = "HTTP/1.1 400.100 Not Ok\r\n\r\nText"
 sock = FakeSocket(body)
 resp = httplib.HTTPResponse(sock,1)
 try:
-    resp.begin()
+    resp._begin()
 except httplib.BadStatusLine:
     print "BadStatusLine raised as expected"
 else:
index 1df83d68d31b13b895c069657e149271cbf2fcac..ee99645d59b7ce3937021c203a6c8b3fd16c5953 100644 (file)
@@ -43,19 +43,42 @@ def clear_cache():
     _parse_cache = {}
 
 
-def urlparse(url, scheme = '', allow_fragments = 1):
+def urlparse(url, scheme='', allow_fragments=1):
     """Parse a URL into 6 components:
     <scheme>://<netloc>/<path>;<params>?<query>#<fragment>
     Return a 6-tuple: (scheme, netloc, path, params, query, fragment).
     Note that we don't break the components up in smaller bits
     (e.g. netloc is a single string) and we don't expand % escapes."""
+    tuple = urlsplit(url, scheme, allow_fragments)
+    scheme, netloc, url, query, fragment = tuple
+    if scheme in uses_params and ';' in url:
+        url, params = _splitparams(url)
+    else:
+        params = ''
+    return scheme, netloc, url, params, query, fragment
+
+def _splitparams(url):
+    if '/'  in url:
+        i = url.find(';', url.rfind('/'))
+        if i < 0:
+            return url, ''
+    else:
+        i = url.find(';')
+    return url[:i], url[i+1:]
+
+def urlsplit(url, scheme='', allow_fragments=1):
+    """Parse a URL into 5 components:
+    <scheme>://<netloc>/<path>?<query>#<fragment>
+    Return a 5-tuple: (scheme, netloc, path, query, fragment).
+    Note that we don't break the components up in smaller bits
+    (e.g. netloc is a single string) and we don't expand % escapes."""
     key = url, scheme, allow_fragments
     cached = _parse_cache.get(key, None)
     if cached:
         return cached
     if len(_parse_cache) >= MAX_CACHE_SIZE: # avoid runaway growth
         clear_cache()
-    netloc = path = params = query = fragment = ''
+    netloc = query = fragment = ''
     i = url.find(':')
     if i > 0:
         if url[:i] == 'http': # optimize the common case
@@ -64,23 +87,16 @@ def urlparse(url, scheme = '', allow_fragments = 1):
             if url[:2] == '//':
                 i = url.find('/', 2)
                 if i < 0:
-                    i = len(url)
+                    i = url.find('#')
+                    if i < 0:
+                        i = len(url)
                 netloc = url[2:i]
                 url = url[i:]
-            if allow_fragments:
-                i = url.rfind('#')
-                if i >= 0:
-                    fragment = url[i+1:]
-                    url = url[:i]
-            i = url.find('?')
-            if i >= 0:
-                query = url[i+1:]
-                url = url[:i]
-            i = url.find(';')
-            if i >= 0:
-                params = url[i+1:]
-                url = url[:i]
-            tuple = scheme, netloc, url, params, query, fragment
+            if allow_fragments and '#' in url:
+                url, fragment = url.split('#', 1)
+            if '?' in url:
+                url, query = url.split('?', 1)
+            tuple = scheme, netloc, url, query, fragment
             _parse_cache[key] = tuple
             return tuple
         for c in url[:i]:
@@ -94,19 +110,11 @@ def urlparse(url, scheme = '', allow_fragments = 1):
             if i < 0:
                 i = len(url)
             netloc, url = url[2:i], url[i:]
-    if allow_fragments and scheme in uses_fragment:
-        i = url.rfind('#')
-        if i >= 0:
-            url, fragment = url[:i], url[i+1:]
-    if scheme in uses_query:
-        i = url.find('?')
-        if i >= 0:
-            url, query = url[:i], url[i+1:]
-    if scheme in uses_params:
-        i = url.find(';')
-        if i >= 0:
-            url, params = url[:i], url[i+1:]
-    tuple = scheme, netloc, url, params, query, fragment
+    if allow_fragments and scheme in uses_fragment and '#' in url:
+        url, fragment = url.split('#', 1)
+    if scheme in uses_query and '?' in url:
+        url, query = url.split('?', 1)
+    tuple = scheme, netloc, url, query, fragment
     _parse_cache[key] = tuple
     return tuple
 
@@ -115,13 +123,16 @@ def urlunparse((scheme, netloc, url, params, query, fragment)):
     slightly different, but equivalent URL, if the URL that was parsed
     originally had redundant delimiters, e.g. a ? with an empty query
     (the draft states that these are equivalent)."""
+    if params:
+        url = "%s;%s" % (url, params)
+    return urlunsplit((scheme, netloc, url, query, fragment))
+
+def urlunsplit((scheme, netloc, url, query, fragment)):
     if netloc or (scheme in uses_netloc and url[:2] == '//'):
         if url and url[:1] != '/': url = '/' + url
         url = '//' + (netloc or '') + url
     if scheme:
         url = scheme + ':' + url
-    if params:
-        url = url + ';' + params
     if query:
         url = url + '?' + query
     if fragment:
@@ -187,9 +198,12 @@ def urldefrag(url):
     the URL contained no fragments, the second element is the
     empty string.
     """
-    s, n, p, a, q, frag = urlparse(url)
-    defrag = urlunparse((s, n, p, a, q, ''))
-    return defrag, frag
+    if '#' in url:
+        s, n, p, a, q, frag = urlparse(url)
+        defrag = urlunparse((s, n, p, a, q, ''))
+        return defrag, frag
+    else:
+        return url, ''
 
 
 test_input = """