]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Use context managers in test_ssl to simplify test writing.
authorAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:54:45 +0000 (16:54 +0100)
committerAntoine Pitrou <solipsis@pitrou.net>
Wed, 21 Dec 2011 15:54:45 +0000 (16:54 +0100)
1  2 
Lib/test/test_ssl.py

index 1960e143473134764aec6e409541aaa7ec27f34e,e9fbc8afebb9e1fe5e53b13aa382d1a9291319d8..d549799396071402d22925c6d49a749a1023f2c6
@@@ -1189,15 -1094,7 +1198,12 @@@ else
              if connectionchatty:
                  if support.verbose:
                      sys.stdout.write(" client:  closing connection.\n")
 +            stats = {
 +                'compression': s.compression(),
 +                'cipher': s.cipher(),
 +            }
              s.close()
-         finally:
-             server.stop()
-             server.join()
 +            return stats
  
      def try_protocol_combo(server_protocol, client_protocol, expect_success,
                             certsreqs=None, server_options=0, client_options=0):
                              )
                          # consume data
                          s.read()
 +
 +                # Make sure sendmsg et al are disallowed to avoid
 +                # inadvertent disclosure of data and/or corruption
 +                # of the encrypted data stream
 +                self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
 +                self.assertRaises(NotImplementedError, s.recvmsg, 100)
 +                self.assertRaises(NotImplementedError,
 +                                  s.recvmsg_into, bytearray(100))
 +
                  s.write(b"over\n")
                  s.close()
-             finally:
-                 server.stop()
-                 server.join()
  
          def test_handshake_timeout(self):
              # Issue #5103: SSL handshake must respect the socket timeout
                  t.join()
                  server.close()
  
-             flag = threading.Event()
-             server.start(flag)
-             # wait for it to start
-             flag.wait()
-             # try to connect
-             s = ssl.wrap_socket(socket.socket(),
-                                 server_side=False,
-                                 certfile=CERTFILE,
-                                 ca_certs=CERTFILE,
-                                 cert_reqs=ssl.CERT_NONE,
-                                 ssl_version=ssl.PROTOCOL_TLSv1)
-             s.connect((HOST, server.port))
-             try:
 +        @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
 +                             "'tls-unique' channel binding not available")
 +        def test_tls_unique_channel_binding(self):
 +            """Test tls-unique channel binding."""
 +            if support.verbose:
 +                sys.stdout.write("\n")
 +
 +            server = ThreadedEchoServer(CERTFILE,
 +                                        certreqs=ssl.CERT_NONE,
 +                                        ssl_version=ssl.PROTOCOL_TLSv1,
 +                                        cacerts=CERTFILE,
 +                                        chatty=True,
 +                                        connectionchatty=False)
-             finally:
-                 server.stop()
-                 server.join()
++            with server:
++                s = ssl.wrap_socket(socket.socket(),
++                                    server_side=False,
++                                    certfile=CERTFILE,
++                                    ca_certs=CERTFILE,
++                                    cert_reqs=ssl.CERT_NONE,
++                                    ssl_version=ssl.PROTOCOL_TLSv1)
++                s.connect((HOST, server.port))
 +                # get the data
 +                cb_data = s.get_channel_binding("tls-unique")
 +                if support.verbose:
 +                    sys.stdout.write(" got channel binding data: {0!r}\n"
 +                                     .format(cb_data))
 +
 +                # check if it is sane
 +                self.assertIsNotNone(cb_data)
 +                self.assertEqual(len(cb_data), 12) # True for TLSv1
 +
 +                # and compare with the peers version
 +                s.write(b"CB tls-unique\n")
 +                peer_data_repr = s.read().strip()
 +                self.assertEqual(peer_data_repr,
 +                                 repr(cb_data).encode("us-ascii"))
 +                s.close()
 +
 +                # now, again
 +                s = ssl.wrap_socket(socket.socket(),
 +                                    server_side=False,
 +                                    certfile=CERTFILE,
 +                                    ca_certs=CERTFILE,
 +                                    cert_reqs=ssl.CERT_NONE,
 +                                    ssl_version=ssl.PROTOCOL_TLSv1)
 +                s.connect((HOST, server.port))
 +                new_cb_data = s.get_channel_binding("tls-unique")
 +                if support.verbose:
 +                    sys.stdout.write(" got another channel binding data: {0!r}\n"
 +                                     .format(new_cb_data))
 +                # is it really unique
 +                self.assertNotEqual(cb_data, new_cb_data)
 +                self.assertIsNotNone(cb_data)
 +                self.assertEqual(len(cb_data), 12) # True for TLSv1
 +                s.write(b"CB tls-unique\n")
 +                peer_data_repr = s.read().strip()
 +                self.assertEqual(peer_data_repr,
 +                                 repr(new_cb_data).encode("us-ascii"))
 +                s.close()
 +
 +        def test_compression(self):
 +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
 +            context.load_cert_chain(CERTFILE)
 +            stats = server_params_test(context, context,
 +                                       chatty=True, connectionchatty=True)
 +            if support.verbose:
 +                sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
 +            self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
 +
 +        @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
 +                             "ssl.OP_NO_COMPRESSION needed for this test")
 +        def test_compression_disabled(self):
 +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
 +            context.load_cert_chain(CERTFILE)
 +            context.options |= ssl.OP_NO_COMPRESSION
 +            stats = server_params_test(context, context,
 +                                       chatty=True, connectionchatty=True)
 +            self.assertIs(stats['compression'], None)
 +
  
  def test_main(verbose=False):
      if support.verbose: