]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Backport various bug fixes from trunk.
authorJeremy Hylton <jeremy@alum.mit.edu>
Tue, 2 Jul 2002 17:19:47 +0000 (17:19 +0000)
committerJeremy Hylton <jeremy@alum.mit.edu>
Tue, 2 Jul 2002 17:19:47 +0000 (17:19 +0000)
The 2.2 maintenace branch is now identical to the trunk through rev
1.53.

Lib/httplib.py

index cbe6e8ffee589cf8b5f36a12919777b2460c2db0..5a039e7a1c07faa5389423ecebbe3a3dc7346df1 100644 (file)
@@ -78,10 +78,10 @@ except ImportError:
 
 __all__ = ["HTTP", "HTTPResponse", "HTTPConnection", "HTTPSConnection",
            "HTTPException", "NotConnected", "UnknownProtocol",
-           "UnknownTransferEncoding", "IllegalKeywordArgument",
-           "UnimplementedFileMode", "IncompleteRead",
-           "ImproperConnectionState", "CannotSendRequest", "CannotSendHeader",
-           "ResponseNotReady", "BadStatusLine", "error"]
+           "UnknownTransferEncoding", "UnimplementedFileMode",
+           "IncompleteRead", "InvalidURL", "ImproperConnectionState",
+           "CannotSendRequest", "CannotSendHeader", "ResponseNotReady",
+           "BadStatusLine", "error"]
 
 HTTP_PORT = 80
 HTTPS_PORT = 443
@@ -111,11 +111,7 @@ class HTTPResponse:
         self.length = _UNKNOWN          # number of bytes left in response
         self.will_close = _UNKNOWN      # conn will close at end of response
 
-    def begin(self):
-        if self.msg is not None:
-            # we've already started reading the response
-            return
-
+    def _read_status(self):
         line = self.fp.readline()
         if self.debuglevel > 0:
             print "reply:", repr(line)
@@ -135,13 +131,33 @@ class HTTPResponse:
 
         # The status code is a three-digit number
         try:
-            self.status = status = int(status)
+            status = int(status)
             if status < 100 or status > 999:
                 raise BadStatusLine(line)
         except ValueError:
             raise BadStatusLine(line)
-        self.reason = reason.strip()
+        return version, status, reason
 
+    def _begin(self):
+        if self.msg is not None:
+            # we've already started reading the response
+            return
+
+        # read until we get a non-100 response
+        while 1:
+            version, status, reason = self._read_status()
+            if status != 100:
+                break
+            # skip the header from the 100 response
+            while 1:
+                skip = self.fp.readline().strip()
+                if not skip:
+                    break
+                if self.debuglevel > 0:
+                    print "header:", skip
+            
+        self.status = status
+        self.reason = reason.strip()
         if version == 'HTTP/1.0':
             self.version = 10
         elif version.startswith('HTTP/1.'):
@@ -152,6 +168,7 @@ class HTTPResponse:
             raise UnknownProtocol(version)
 
         if self.version == 9:
+            self.chunked = 0
             self.msg = mimetools.Message(StringIO())
             return
 
@@ -233,6 +250,7 @@ class HTTPResponse:
             return ''
 
         if self.chunked:
+            assert self.chunked != _UNKNOWN
             chunk_left = self.chunk_left
             value = ''
             while 1:
@@ -347,7 +365,10 @@ class HTTPConnection:
         if port is None:
             i = host.find(':')
             if i >= 0:
-                port = int(host[i+1:])
+                try:
+                    port = int(host[i+1:])
+                except ValueError:
+                    raise InvalidURL, "nonnumeric port: '%s'"%host[i+1:]
                 host = host[:i]
             else:
                 port = self.default_port
@@ -360,7 +381,8 @@ class HTTPConnection:
     def connect(self):
         """Connect to the host and port specified in __init__."""
         msg = "getaddrinfo returns an empty list"
-        for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
+        for res in socket.getaddrinfo(self.host, self.port, 0,
+                                      socket.SOCK_STREAM):
             af, socktype, proto, canonname, sa = res
             try:
                 self.sock = socket.socket(af, socktype, proto)
@@ -546,7 +568,7 @@ class HTTPConnection:
         # If headers already contains a host header, then define the
         # optional skip_host argument to putrequest().  The check is
         # harder because field names are case insensitive.
-        if (headers.has_key('Host')
+        if 'Host' in (headers
             or [k for k in headers.iterkeys() if k.lower() == "host"]):
             self.putrequest(method, url, skip_host=1)
         else:
@@ -592,7 +614,8 @@ class HTTPConnection:
         else:
             response = self.response_class(self.sock)
 
-        response.begin()
+        response._begin()
+        assert response.will_close != _UNKNOWN
         self.__state = _CS_IDLE
 
         if response.will_close:
@@ -604,45 +627,93 @@ class HTTPConnection:
 
         return response
 
+class SSLFile:
+    """File-like object wrapping an SSL socket."""
 
-class FakeSocket:
-    def __init__(self, sock, ssl):
-        self.__sock = sock
-        self.__ssl = ssl
-
-    def makefile(self, mode, bufsize=None):
-        """Return a readable file-like object with data from socket.
-
-        This method offers only partial support for the makefile
-        interface of a real socket.  It only supports modes 'r' and
-        'rb' and the bufsize argument is ignored.
+    BUFSIZE = 8192
+    
+    def __init__(self, sock, ssl, bufsize=None):
+        self._sock = sock
+        self._ssl = ssl
+        self._buf = ''
+        self._bufsize = bufsize or self.__class__.BUFSIZE
 
-        The returned object contains *all* of the file data
-        """
-        if mode != 'r' and mode != 'rb':
-            raise UnimplementedFileMode()
-
-        msgbuf = []
+    def _read(self):
+        buf = ''
+        # put in a loop so that we retry on transient errors
         while 1:
             try:
-                buf = self.__ssl.read()
+                buf = self._ssl.read(self._bufsize)
             except socket.sslerror, err:
                 if (err[0] == socket.SSL_ERROR_WANT_READ
-                    or err[0] == socket.SSL_ERROR_WANT_WRITE
-                    or 0):
+                    or err[0] == socket.SSL_ERROR_WANT_WRITE):
                     continue
-                if (err[0] == socket.SSL_ERROR_ZERO_RETURN 
-                       or err[0] == socket.SSL_ERROR_EOF):
+                if (err[0] == socket.SSL_ERROR_ZERO_RETURN
+                    or err[0] == socket.SSL_ERROR_EOF):
                     break
                 raise
             except socket.error, err:
                 if err[0] == errno.EINTR:
                     continue
+                if err[0] == errno.EBADF:
+                    # XXX socket was closed?
+                    break
                 raise
-            if buf == '':
+            else:
                 break
-            msgbuf.append(buf)
-        return StringIO("".join(msgbuf))
+        return buf
+
+    def read(self, size=None):
+        L = [self._buf]
+        avail = len(self._buf)
+        while size is None or avail < size:
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+            avail += len(s)
+        all = "".join(L)
+        if size is None:
+            self._buf = ''
+            return all
+        else:
+            self._buf = all[size:]
+            return all[:size]
+
+    def readline(self):
+        L = [self._buf]
+        self._buf = ''
+        while 1:
+            i = L[-1].find("\n")
+            if i >= 0:
+                break
+            s = self._read()
+            if s == '':
+                break
+            L.append(s)
+        if i == -1:
+            # loop exited because there is no more data
+            return "".join(L)
+        else:
+            all = "".join(L)
+            # XXX could do enough bookkeeping not to do a 2nd search
+            i = all.find("\n") + 1
+            line = all[:i]
+            self._buf = all[i:]
+            return line
+
+    def close(self):
+        self._sock.close()
+
+class FakeSocket:
+    def __init__(self, sock, ssl):
+        self.__sock = sock
+        self.__ssl = ssl
+
+    def makefile(self, mode, bufsize=None):
+        if mode != 'r' and mode != 'rb':
+            raise UnimplementedFileMode()
+        return SSLFile(self.__sock, self.__ssl, bufsize)
 
     def send(self, stuff, flags = 0):
         return self.__ssl.write(stuff)
@@ -662,21 +733,10 @@ class HTTPSConnection(HTTPConnection):
 
     default_port = HTTPS_PORT
 
-    def __init__(self, host, port=None, **x509):
-        keys = x509.keys()
-        try:
-            keys.remove('key_file')
-        except ValueError:
-            pass
-        try:
-            keys.remove('cert_file')
-        except ValueError:
-            pass
-        if keys:
-            raise IllegalKeywordArgument()
+    def __init__(self, host, port=None, key_file=None, cert_file=None):
         HTTPConnection.__init__(self, host, port)
-        self.key_file = x509.get('key_file')
-        self.cert_file = x509.get('cert_file')
+        self.key_file = key_file
+        self.cert_file = cert_file
 
     def connect(self):
         "Connect to a host on a given (SSL) port."
@@ -810,6 +870,9 @@ class HTTPException(Exception):
 class NotConnected(HTTPException):
     pass
 
+class InvalidURL(HTTPException):
+    pass
+
 class UnknownProtocol(HTTPException):
     def __init__(self, version):
         self.version = version
@@ -817,9 +880,6 @@ class UnknownProtocol(HTTPException):
 class UnknownTransferEncoding(HTTPException):
     pass
 
-class IllegalKeywordArgument(HTTPException):
-    pass
-
 class UnimplementedFileMode(HTTPException):
     pass
 
@@ -880,7 +940,7 @@ def test():
     if headers:
         for header in headers.headers: print header.strip()
     print
-    print h.getfile().read()
+    print "read", len(h.getfile().read())
 
     # minimal test that code to extract host from url works
     class HTTP11(HTTP):
@@ -901,13 +961,14 @@ def test():
         hs.putrequest('GET', selector)
         hs.endheaders()
         status, reason, headers = hs.getreply()
+        # XXX why does this give a 302 response?
         print 'status =', status
         print 'reason =', reason
         print
         if headers:
             for header in headers.headers: print header.strip()
         print
-        print hs.getfile().read()
+        print "read", len(hs.getfile().read())
 
 
 if __name__ == '__main__':