]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Backport changes.
author Jeremy Hylton <jeremy@alum.mit.edu>
Fri, 12 Jul 2002 14:23:43 +0000 (14:23 +0000)
committer Jeremy Hylton <jeremy@alum.mit.edu>
Fri, 12 Jul 2002 14:23:43 +0000 (14:23 +0000)
Change _begin() back to begin().
Fix for SF bug 579107.
Fix for SF bug #432621: httplib: multiple Set-Cookie headers
Fix SF bug #575360
Handle HTTP/0.9 responses.

Lib/httplib.py
Lib/test/test_httplib.py

index 5a039e7a1c07faa5389423ecebbe3a3dc7346df1..34ed2da57f5c508e999b60ddf337f8a397f22402 100644 (file)
@@ -93,11 +93,126 @@ _CS_IDLE = 'Idle'
 _CS_REQ_STARTED = 'Request-started'
 _CS_REQ_SENT = 'Request-sent'
 
+class HTTPMessage(mimetools.Message):
+
+    def addheader(self, key, value):
+        """Add header for field key handling repeats."""
+        prev = self.dict.get(key)
+        if prev is None:
+            self.dict[key] = value
+        else:
+            combined = ", ".join((prev, value))
+            self.dict[key] = combined
+
+    def addcontinue(self, key, more):
+        """Add more field data from a continuation line."""
+        prev = self.dict[key]
+        self.dict[key] = prev + "\n " + more
+
+    def readheaders(self):
+        """Read header lines.
+
+        Read header lines up to the entirely blank line that terminates them.
+        The (normally blank) line that ends the headers is skipped, but not
+        included in the returned list.  If a non-header line ends the headers,
+        (which is an error), an attempt is made to backspace over it; it is
+        never included in the returned list.
+
+        The variable self.status is set to the empty string if all went well,
+        otherwise it is an error message.  The variable self.headers is a
+        completely uninterpreted list of lines contained in the header (so
+        printing them will reproduce the header exactly as it appears in the
+        file).
+
+        If multiple header fields with the same name occur, they are combined
+        according to the rules in RFC 2616 sec 4.2:
+
+        Appending each subsequent field-value to the first, each separated
+        by a comma. The order in which header fields with the same field-name
+        are received is significant to the interpretation of the combined
+        field value.
+        """
+        # XXX The implementation overrides the readheaders() method of
+        # rfc822.Message.  The base class design isn't amenable to
+        # customized behavior here so the method here is a copy of the
+        # base class code with a few small changes.
+
+        self.dict = {}
+        self.unixfrom = ''
+        self.headers = list = []
+        self.status = ''
+        headerseen = ""
+        firstline = 1
+        startofline = unread = tell = None
+        if hasattr(self.fp, 'unread'):
+            unread = self.fp.unread
+        elif self.seekable:
+            tell = self.fp.tell
+        while 1:
+            if tell:
+                try:
+                    startofline = tell()
+                except IOError:
+                    startofline = tell = None
+                    self.seekable = 0
+            line = self.fp.readline()
+            if not line:
+                self.status = 'EOF in headers'
+                break
+            # Skip unix From name time lines
+            if firstline and line.startswith('From '):
+                self.unixfrom = self.unixfrom + line
+                continue
+            firstline = 0
+            if headerseen and line[0] in ' \t':
+                # XXX Not sure if continuation lines are handled properly
+                # for http and/or for repeating headers
+                # It's a continuation line.
+                list.append(line)
+                x = self.dict[headerseen] + "\n " + line.strip()
+                self.addcontinue(headerseen, line.strip())
+                continue
+            elif self.iscomment(line):
+                # It's a comment.  Ignore it.
+                continue
+            elif self.islast(line):
+                # Note! No pushback here!  The delimiter line gets eaten.
+                break
+            headerseen = self.isheader(line)
+            if headerseen:
+                # It's a legal header line, save it.
+                list.append(line)
+                self.addheader(headerseen, line[len(headerseen)+1:].strip())
+                continue
+            else:
+                # It's not a header line; throw it back and stop here.
+                if not self.dict:
+                    self.status = 'No headers'
+                else:
+                    self.status = 'Non-header line where header expected'
+                # Try to undo the read.
+                if unread:
+                    unread(line)
+                elif tell:
+                    self.fp.seek(startofline)
+                else:
+                    self.status = self.status + '; bad seek'
+                break
 
 class HTTPResponse:
-    def __init__(self, sock, debuglevel=0):
+
+    # strict: If true, raise BadStatusLine if the status line can't be
+    # parsed as a valid HTTP/1.0 or 1.1 status line.  By default it is
+    # false because it prevents clients from talking to HTTP/0.9
+    # servers.  Note that a response with a sufficiently corrupted
+    # status line will look like an HTTP/0.9 response.
+
+    # See RFC 2616 sec 19.6 and RFC 1945 sec 6 for details.
+
+    def __init__(self, sock, debuglevel=0, strict=0):
         self.fp = sock.makefile('rb', 0)
         self.debuglevel = debuglevel
+        self.strict = strict
 
         self.msg = None
 
@@ -112,6 +227,7 @@ class HTTPResponse:
         self.will_close = _UNKNOWN      # conn will close at end of response
 
     def _read_status(self):
+        # Initialize with Simple-Response defaults
         line = self.fp.readline()
         if self.debuglevel > 0:
             print "reply:", repr(line)
@@ -122,12 +238,17 @@ class HTTPResponse:
                 [version, status] = line.split(None, 1)
                 reason = ""
             except ValueError:
-                version = "HTTP/0.9"
-                status = "200"
-                reason = ""
-        if version[:5] != 'HTTP/':
-            self.close()
-            raise BadStatusLine(line)
+                # empty version will cause next test to fail and status
+                # will be treated as 0.9 response.
+                version = ""
+        if not version.startswith('HTTP/'):
+            if self.strict:
+                self.close()
+                raise BadStatusLine(line)
+            else:
+                # assume it's a Simple-Response from an 0.9 server
+                self.fp = LineAndFileWrapper(line, self.fp)
+                return "HTTP/0.9", 200, ""
 
         # The status code is a three-digit number
         try:
@@ -138,7 +259,7 @@ class HTTPResponse:
             raise BadStatusLine(line)
         return version, status, reason
 
-    def _begin(self):
+    def begin(self):
         if self.msg is not None:
             # we've already started reading the response
             return
@@ -169,10 +290,11 @@ class HTTPResponse:
 
         if self.version == 9:
             self.chunked = 0
-            self.msg = mimetools.Message(StringIO())
+            self.will_close = 1
+            self.msg = HTTPMessage(StringIO())
             return
 
-        self.msg = mimetools.Message(self.fp, 0)
+        self.msg = HTTPMessage(self.fp, 0)
         if self.debuglevel > 0:
             for hdr in self.msg.headers:
                 print "header:", hdr,
@@ -353,13 +475,16 @@ class HTTPConnection:
     default_port = HTTP_PORT
     auto_open = 1
     debuglevel = 0
+    strict = 0
 
-    def __init__(self, host, port=None):
+    def __init__(self, host, port=None, strict=None):
         self.sock = None
         self.__response = None
         self.__state = _CS_IDLE
-
+        
         self._set_hostport(host, port)
+        if strict is not None:
+            self.strict = strict
 
     def _set_hostport(self, host, port):
         if port is None:
@@ -368,7 +493,7 @@ class HTTPConnection:
                 try:
                     port = int(host[i+1:])
                 except ValueError:
-                    raise InvalidURL, "nonnumeric port: '%s'"%host[i+1:]
+                    raise InvalidURL("nonnumeric port: '%s'" % host[i+1:])
                 host = host[:i]
             else:
                 port = self.default_port
@@ -610,11 +735,12 @@ class HTTPConnection:
             raise ResponseNotReady()
 
         if self.debuglevel > 0:
-            response = self.response_class(self.sock, self.debuglevel)
+            response = self.response_class(self.sock, self.debuglevel,
+                                           strict=self.strict)
         else:
-            response = self.response_class(self.sock)
+            response = self.response_class(self.sock, strict=self.strict)
 
-        response._begin()
+        response.begin()
         assert response.will_close != _UNKNOWN
         self.__state = _CS_IDLE
 
@@ -627,13 +753,59 @@ class HTTPConnection:
 
         return response
 
-class SSLFile:
+# The next several classes are used to define FakeSocket, a socket-like
+# interface to an SSL connection.
+
+# The primary complexity comes from faking a makefile() method.  The
+# standard socket makefile() implementation calls dup() on the socket
+# file descriptor.  As a consequence, clients can call close() on the
+# parent socket and its makefile children in any order.  The underlying
+# socket isn't closed until they are all closed.
+
+# The implementation uses reference counting to keep the socket open
+# until the last client calls close().  SharedSocket keeps track of
+# the reference counting and SharedSocketClient provides a constructor
+# and close() method that call incref() and decref() correctly.
+
+class SharedSocket:
+
+    def __init__(self, sock):
+        self.sock = sock
+        self._refcnt = 0
+
+    def incref(self):
+        self._refcnt += 1
+
+    def decref(self):
+        self._refcnt -= 1
+        assert self._refcnt >= 0
+        if self._refcnt == 0:
+            self.sock.close()
+
+    def __del__(self):
+        self.sock.close()
+
+class SharedSocketClient:
+
+    def __init__(self, shared):
+        self._closed = 0
+        self._shared = shared
+        self._shared.incref()
+        self._sock = shared.sock
+
+    def close(self):
+        if not self._closed:
+            self._shared.decref()
+            self._closed = 1
+            self._shared = None
+
+class SSLFile(SharedSocketClient):
     """File-like object wrapping an SSL socket."""
 
     BUFSIZE = 8192
     
     def __init__(self, sock, ssl, bufsize=None):
-        self._sock = sock
+        SharedSocketClient.__init__(self, sock)
         self._ssl = ssl
         self._buf = ''
         self._bufsize = bufsize or self.__class__.BUFSIZE
@@ -702,30 +874,36 @@ class SSLFile:
             self._buf = all[i:]
             return line
 
-    def close(self):
-        self._sock.close()
+class FakeSocket(SharedSocketClient):
+
+    class _closedsocket:
+        def __getattr__(self, name):
+            raise error(9, 'Bad file descriptor')
 
-class FakeSocket:
     def __init__(self, sock, ssl):
-        self.__sock = sock
-        self.__ssl = ssl
+        sock = SharedSocket(sock)
+        SharedSocketClient.__init__(self, sock)
+        self._ssl = ssl
+
+    def close(self):
+        SharedSocketClient.close(self)
+        self._sock = self.__class__._closedsocket()
 
     def makefile(self, mode, bufsize=None):
         if mode != 'r' and mode != 'rb':
             raise UnimplementedFileMode()
-        return SSLFile(self.__sock, self.__ssl, bufsize)
+        return SSLFile(self._shared, self._ssl, bufsize)
 
     def send(self, stuff, flags = 0):
-        return self.__ssl.write(stuff)
+        return self._ssl.write(stuff)
 
-    def sendall(self, stuff, flags = 0):
-        return self.__ssl.write(stuff)
+    sendall = send
 
     def recv(self, len = 1024, flags = 0):
-        return self.__ssl.read(len)
+        return self._ssl.read(len)
 
     def __getattr__(self, attr):
-        return getattr(self.__sock, attr)
+        return getattr(self._sock, attr)
 
 
 class HTTPSConnection(HTTPConnection):
@@ -733,8 +911,9 @@ class HTTPSConnection(HTTPConnection):
 
     default_port = HTTPS_PORT
 
-    def __init__(self, host, port=None, key_file=None, cert_file=None):
-        HTTPConnection.__init__(self, host, port)
+    def __init__(self, host, port=None, key_file=None, cert_file=None,
+                 strict=None):
+        HTTPConnection.__init__(self, host, port, strict)
         self.key_file = key_file
         self.cert_file = cert_file
 
@@ -760,7 +939,7 @@ class HTTP:
 
     _connection_class = HTTPConnection
 
-    def __init__(self, host='', port=None):
+    def __init__(self, host='', port=None, strict=None):
         "Provide a default host, since the superclass requires one."
 
         # some joker passed 0 explicitly, meaning default port
@@ -770,7 +949,7 @@ class HTTP:
         # Note that we may pass an empty string as the host; this will throw
         # an error when we attempt to connect. Presumably, the client code
         # will call connect before then, with a proper host.
-        self._setup(self._connection_class(host, port))
+        self._setup(self._connection_class(host, port, strict))
 
     def _setup(self, conn):
         self._conn = conn
@@ -850,21 +1029,25 @@ if hasattr(socket, 'ssl'):
 
         _connection_class = HTTPSConnection
 
-        def __init__(self, host='', port=None, **x509):
+        def __init__(self, host='', port=None, key_file=None, cert_file=None,
+                     strict=None):
             # provide a default host, pass the X509 cert info
 
             # urf. compensate for bad input.
             if port == 0:
                 port = None
-            self._setup(self._connection_class(host, port, **x509))
+            self._setup(self._connection_class(host, port, key_file,
+                                               cert_file, strict))
 
             # we never actually use these for anything, but we keep them
             # here for compatibility with post-1.5.2 CVS.
-            self.key_file = x509.get('key_file')
-            self.cert_file = x509.get('cert_file')
+            self.key_file = key_file
+            self.cert_file = cert_file
 
 
 class HTTPException(Exception):
+    # Subclasses that define an __init__ must call Exception.__init__
+    # or define self.args.  Otherwise, str() will fail.
     pass
 
 class NotConnected(HTTPException):
@@ -875,6 +1058,7 @@ class InvalidURL(HTTPException):
 
 class UnknownProtocol(HTTPException):
     def __init__(self, version):
+        self.args = version,
         self.version = version
 
 class UnknownTransferEncoding(HTTPException):
@@ -885,6 +1069,7 @@ class UnimplementedFileMode(HTTPException):
 
 class IncompleteRead(HTTPException):
     def __init__(self, partial):
+        self.args = partial,
         self.partial = partial
 
 class ImproperConnectionState(HTTPException):
@@ -901,21 +1086,77 @@ class ResponseNotReady(ImproperConnectionState):
 
 class BadStatusLine(HTTPException):
     def __init__(self, line):
+        self.args = line,
         self.line = line
 
 # for backwards compatibility
 error = HTTPException
 
+class LineAndFileWrapper:
+    """A limited file-like object for HTTP/0.9 responses."""
+
+    # The status-line parsing code calls readline(), which normally
+    # gets the HTTP status line.  For a 0.9 response, however, this is
+    # actually the first line of the body!  Clients need to get a
+    # readable file object that contains that line.
+
+    def __init__(self, line, file):
+        self._line = line
+        self._file = file
+        self._line_consumed = 0
+        self._line_offset = 0
+        self._line_left = len(line)
+
+    def __getattr__(self, attr):
+        return getattr(self._file, attr)
+
+    def _done(self):
+        # called when the last byte is read from the line.  After the
+        # call, all read methods are delegated to the underlying file
+        # object.
+        self._line_consumed = 1
+        self.read = self._file.read
+        self.readline = self._file.readline
+        self.readlines = self._file.readlines
+
+    def read(self, amt=None):
+        assert not self._line_consumed and self._line_left
+        if amt is None or amt > self._line_left:
+            s = self._line[self._line_offset:]
+            self._done()
+            if amt is None:
+                return s + self._file.read()
+            else:
+                return s + self._file.read(amt - len(s))                
+        else:
+            assert amt <= self._line_left
+            i = self._line_offset
+            j = i + amt
+            s = self._line[i:j]
+            self._line_offset = j
+            self._line_left -= amt
+            if self._line_left == 0:
+                self._done()
+            return s
+        
+    def readline(self):
+        s = self._line[self._line_offset:]
+        self._done()
+        return s
+
+    def readlines(self, size=None):
+        L = [self._line[self._line_offset:]]
+        self._done()
+        if size is None:
+            return L + self._file.readlines()
+        else:
+            return L + self._file.readlines(size)
 
-#
-# snarfed from httplib.py for now...
-#
 def test():
     """Test this module.
 
-    The test consists of retrieving and displaying the Python
-    home page, along with the error code and error string returned
-    by the www.python.org server.
+    A hodge podge of tests collected here, because they have too many
+    external dependencies for the regular test suite.
     """
 
     import sys
@@ -936,11 +1177,11 @@ def test():
     status, reason, headers = h.getreply()
     print 'status =', status
     print 'reason =', reason
+    print "read", len(h.getfile().read())
     print
     if headers:
         for header in headers.headers: print header.strip()
     print
-    print "read", len(h.getfile().read())
 
     # minimal test that code to extract host from url works
     class HTTP11(HTTP):
@@ -954,22 +1195,57 @@ def test():
     h.close()
 
     if hasattr(socket, 'ssl'):
-        host = 'sourceforge.net'
-        selector = '/projects/python'
-        hs = HTTPS()
-        hs.connect(host)
-        hs.putrequest('GET', selector)
-        hs.endheaders()
-        status, reason, headers = hs.getreply()
-        # XXX why does this give a 302 response?
-        print 'status =', status
-        print 'reason =', reason
-        print
-        if headers:
-            for header in headers.headers: print header.strip()
-        print
-        print "read", len(hs.getfile().read())
-
+        
+        for host, selector in (('sourceforge.net', '/projects/python'),
+                               ('dbserv2.theopalgroup.com', '/mediumfile'),
+                               ('dbserv2.theopalgroup.com', '/smallfile'),
+                               ):
+            print "https://%s%s" % (host, selector)
+            hs = HTTPS()
+            hs.connect(host)
+            hs.putrequest('GET', selector)
+            hs.endheaders()
+            status, reason, headers = hs.getreply()
+            print 'status =', status
+            print 'reason =', reason
+            print "read", len(hs.getfile().read())
+            print
+            if headers:
+                for header in headers.headers: print header.strip()
+            print
+
+    return
+
+    # Test a buggy server -- returns garbled status line.
+    # http://www.yahoo.com/promotions/mom_com97/supermom.html
+    c = HTTPConnection("promotions.yahoo.com")
+    c.set_debuglevel(1)
+    c.connect()
+    c.request("GET", "/promotions/mom_com97/supermom.html")
+    r = c.getresponse()
+    print r.status, r.version
+    lines = r.read().split("\n")
+    print "\n".join(lines[:5])
+
+    c = HTTPConnection("promotions.yahoo.com", strict=1)
+    c.set_debuglevel(1)
+    c.connect()
+    c.request("GET", "/promotions/mom_com97/supermom.html")
+    try:
+        r = c.getresponse()
+    except BadStatusLine, err:
+        print "strict mode failed as expected"
+        print err
+    else:
+        print "XXX strict mode should have failed"
+
+    for strict in 0, 1:
+        h = HTTP(strict=strict)
+        h.connect("promotions.yahoo.com")
+        h.putrequest('GET', "/promotions/mom_com97/supermom.html")
+        h.endheaders()
+        status, reason, headers = h.getreply()
+        assert (strict and status == -1) or status == 200, (strict, status)
 
 if __name__ == '__main__':
     test()
index 6270d8b1aa3d79075cac15afbf21f36752ede47d..09f92fc46a9f40e9f465a0377a1ad625a68902d2 100644 (file)
@@ -8,24 +8,52 @@ class FakeSocket:
 
     def makefile(self, mode, bufsize=None):
         if mode != 'r' and mode != 'rb':
-            raise UnimplementedFileMode()
+            raise httplib.UnimplementedFileMode()
         return StringIO.StringIO(self.text)
 
 # Test HTTP status lines
 
 body = "HTTP/1.1 200 Ok\r\n\r\nText"
 sock = FakeSocket(body)
-resp = httplib.HTTPResponse(sock,1)
-resp._begin()
+resp = httplib.HTTPResponse(sock, 1)
+resp.begin()
 print resp.read()
 resp.close()
 
 body = "HTTP/1.1 400.100 Not Ok\r\n\r\nText"
 sock = FakeSocket(body)
-resp = httplib.HTTPResponse(sock,1)
+resp = httplib.HTTPResponse(sock, 1)
 try:
-    resp._begin()
+    resp.begin()
 except httplib.BadStatusLine:
     print "BadStatusLine raised as expected"
 else:
     print "Expect BadStatusLine"
+
+# Check invalid host_port
+
+for hp in ("www.python.org:abc", "www.python.org:"):
+    try:
+        h = httplib.HTTP(hp)
+    except httplib.InvalidURL:
+        print "InvalidURL raised as expected"
+    else:
+        print "Expect InvalidURL"
+
+# test response with multiple message headers with the same field name.
+text = ('HTTP/1.1 200 OK\r\n'
+        'Set-Cookie: Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"\r\n'
+        'Set-Cookie: Part_Number="Rocket_Launcher_0001"; Version="1";'
+        ' Path="/acme"\r\n'
+        '\r\n'
+        'No body\r\n')
+hdr = ('Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"'
+       ', '
+       'Part_Number="Rocket_Launcher_0001"; Version="1"; Path="/acme"')
+s = FakeSocket(text)
+r = httplib.HTTPResponse(s, 1)
+r.begin()
+cookies = r.getheader("Set-Cookie")
+if cookies != hdr:
+    raise AssertionError, "multiple headers not combined properly"
+