]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-113280: Always close socket if SSLSocket creation failed (GH-114659)
authorSerhiy Storchaka <storchaka@gmail.com>
Sun, 4 Feb 2024 15:28:07 +0000 (17:28 +0200)
committerGitHub <noreply@github.com>
Sun, 4 Feb 2024 15:28:07 +0000 (15:28 +0000)
Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Lib/ssl.py
Lib/test/test_ssl.py
Misc/NEWS.d/next/Library/2024-01-27-20-11-24.gh-issue-113280.CZPQMf.rst [new file with mode: 0644]

index 74a9d2d8fd4fb01c64b3be2940ec11a2e142f176..03d0121891ff4cb4a636ded561fc1b6177860af0 100644 (file)
@@ -994,71 +994,67 @@ class SSLSocket(socket):
         if context.check_hostname and not server_hostname:
             raise ValueError("check_hostname requires server_hostname")
 
+        sock_timeout = sock.gettimeout()
         kwargs = dict(
             family=sock.family, type=sock.type, proto=sock.proto,
             fileno=sock.fileno()
         )
         self = cls.__new__(cls, **kwargs)
         super(SSLSocket, self).__init__(**kwargs)
-        sock_timeout = sock.gettimeout()
         sock.detach()
-
-        self._context = context
-        self._session = session
-        self._closed = False
-        self._sslobj = None
-        self.server_side = server_side
-        self.server_hostname = context._encode_hostname(server_hostname)
-        self.do_handshake_on_connect = do_handshake_on_connect
-        self.suppress_ragged_eofs = suppress_ragged_eofs
-
-        # See if we are connected
+        # Now SSLSocket is responsible for closing the file descriptor.
         try:
-            self.getpeername()
-        except OSError as e:
-            if e.errno != errno.ENOTCONN:
-                raise
-            connected = False
-            blocking = self.getblocking()
-            self.setblocking(False)
+            self._context = context
+            self._session = session
+            self._closed = False
+            self._sslobj = None
+            self.server_side = server_side
+            self.server_hostname = context._encode_hostname(server_hostname)
+            self.do_handshake_on_connect = do_handshake_on_connect
+            self.suppress_ragged_eofs = suppress_ragged_eofs
+
+            # See if we are connected
             try:
-                # We are not connected so this is not supposed to block, but
-                # testing revealed otherwise on macOS and Windows so we do
-                # the non-blocking dance regardless. Our raise when any data
-                # is found means consuming the data is harmless.
-                notconn_pre_handshake_data = self.recv(1)
+                self.getpeername()
             except OSError as e:
-                # EINVAL occurs for recv(1) on non-connected on unix sockets.
-                if e.errno not in (errno.ENOTCONN, errno.EINVAL):
+                if e.errno != errno.ENOTCONN:
                     raise
-                notconn_pre_handshake_data = b''
-            self.setblocking(blocking)
-            if notconn_pre_handshake_data:
-                # This prevents pending data sent to the socket before it was
-                # closed from escaping to the caller who could otherwise
-                # presume it came through a successful TLS connection.
-                reason = "Closed before TLS handshake with data in recv buffer."
-                notconn_pre_handshake_data_error = SSLError(e.errno, reason)
-                # Add the SSLError attributes that _ssl.c always adds.
-                notconn_pre_handshake_data_error.reason = reason
-                notconn_pre_handshake_data_error.library = None
-                try:
-                    self.close()
-                except OSError:
-                    pass
+                connected = False
+                blocking = self.getblocking()
+                self.setblocking(False)
                 try:
-                    raise notconn_pre_handshake_data_error
-                finally:
-                    # Explicitly break the reference cycle.
-                    notconn_pre_handshake_data_error = None
-        else:
-            connected = True
+                    # We are not connected so this is not supposed to block, but
+                    # testing revealed otherwise on macOS and Windows so we do
+                    # the non-blocking dance regardless. Our raise when any data
+                    # is found means consuming the data is harmless.
+                    notconn_pre_handshake_data = self.recv(1)
+                except OSError as e:
+                    # EINVAL occurs for recv(1) on non-connected on unix sockets.
+                    if e.errno not in (errno.ENOTCONN, errno.EINVAL):
+                        raise
+                    notconn_pre_handshake_data = b''
+                self.setblocking(blocking)
+                if notconn_pre_handshake_data:
+                    # This prevents pending data sent to the socket before it was
+                    # closed from escaping to the caller who could otherwise
+                    # presume it came through a successful TLS connection.
+                    reason = "Closed before TLS handshake with data in recv buffer."
+                    notconn_pre_handshake_data_error = SSLError(e.errno, reason)
+                    # Add the SSLError attributes that _ssl.c always adds.
+                    notconn_pre_handshake_data_error.reason = reason
+                    notconn_pre_handshake_data_error.library = None
+                    try:
+                        raise notconn_pre_handshake_data_error
+                    finally:
+                        # Explicitly break the reference cycle.
+                        notconn_pre_handshake_data_error = None
+            else:
+                connected = True
 
-        self.settimeout(sock_timeout)  # Must come after setblocking() calls.
-        self._connected = connected
-        if connected:
-            # create the SSL object
-            try:
+            self.settimeout(sock_timeout)  # Must come after setblocking() calls.
+            self._connected = connected
+            if connected:
+                # create the SSL object
                 self._sslobj = self._context._wrap_socket(
                     self, server_side, self.server_hostname,
                     owner=self, session=self._session,
@@ -1069,9 +1065,12 @@ class SSLSocket(socket):
                         # non-blocking
                         raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                     self.do_handshake()
-            except (OSError, ValueError):
+        except:
+            try:
                 self.close()
-                raise
+            except OSError:
+                pass
+            raise
         return self
 
     @property
index 3fdfa2960503b8fc126845c7d05421556a5fd683..1b18230d83577dc218747d02d521146cbe229004 100644 (file)
@@ -2206,14 +2206,15 @@ def _test_get_server_certificate(test, host, port, cert=None):
         sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem))
 
 def _test_get_server_certificate_fail(test, host, port):
-    try:
-        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
-    except ssl.SSLError as x:
-        #should fail
-        if support.verbose:
-            sys.stdout.write("%s\n" % x)
-    else:
-        test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
+    with warnings_helper.check_no_resource_warning(test):
+        try:
+            pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
+        except ssl.SSLError as x:
+            #should fail
+            if support.verbose:
+                sys.stdout.write("%s\n" % x)
+        else:
+            test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
 
 
 from test.ssl_servers import make_https_server
@@ -3026,6 +3027,16 @@ class ThreadedTests(unittest.TestCase):
                                      server_hostname="python.example.org") as s:
                 with self.assertRaises(ssl.CertificateError):
                     s.connect((HOST, server.port))
+        with ThreadedEchoServer(context=server_context, chatty=True) as server:
+            with warnings_helper.check_no_resource_warning(self):
+                with self.assertRaises(UnicodeError):
+                    context.wrap_socket(socket.socket(),
+                            server_hostname='.pythontest.net')
+        with ThreadedEchoServer(context=server_context, chatty=True) as server:
+            with warnings_helper.check_no_resource_warning(self):
+                with self.assertRaises(UnicodeDecodeError):
+                    context.wrap_socket(socket.socket(),
+                            server_hostname=b'k\xf6nig.idn.pythontest.net')
 
     def test_wrong_cert_tls12(self):
         """Connecting when the server rejects the client's certificate
@@ -4983,7 +4994,8 @@ class TestPreHandshakeClose(unittest.TestCase):
             self.assertIsNone(wrap_error.library, msg="attr must exist")
         finally:
             # gh-108342: Explicitly break the reference cycle
-            wrap_error = None
+            with warnings_helper.check_no_resource_warning(self):
+                wrap_error = None
             server = None
 
     def test_https_client_non_tls_response_ignored(self):
@@ -5032,7 +5044,8 @@ class TestPreHandshakeClose(unittest.TestCase):
         # socket; that fails if the connection is broken. It may seem pointless
         # to test this. It serves as an illustration of something that we never
         # want to happen... properly not happening.
-        with self.assertRaises(OSError):
+        with warnings_helper.check_no_resource_warning(self), \
+                self.assertRaises(OSError):
             connection.request("HEAD", "/test", headers={"Host": "localhost"})
             response = connection.getresponse()
 
diff --git a/Misc/NEWS.d/next/Library/2024-01-27-20-11-24.gh-issue-113280.CZPQMf.rst b/Misc/NEWS.d/next/Library/2024-01-27-20-11-24.gh-issue-113280.CZPQMf.rst
new file mode 100644 (file)
index 0000000..3dcdbcf
--- /dev/null
@@ -0,0 +1,2 @@
+Fix a leak of open socket in rare cases when error occurred in
+:class:`ssl.SSLSocket` creation.