From: kimbo Date: Mon, 30 Dec 2019 21:02:33 +0000 (-0700) Subject: req'd session param to https() for connection reuse X-Git-Tag: v2.0.0rc1~342^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4d68f1c670c7c7b8ea2c75e69b92019d120bd43d;p=thirdparty%2Fdnspython.git req'd session param to https() for connection reuse --- diff --git a/dns/query.py b/dns/query.py index 23711453..867acab3 100644 --- a/dns/query.py +++ b/dns/query.py @@ -210,7 +210,7 @@ def _destination_and_source(af, where, port, source, source_port): return (af, destination, source) -def https(q, where, timeout=None, port=443, path='/dns-query', post=True, +def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True, verify=True, source=None, source_port=0, one_rr_per_rrset=False, ignore_trailing=False): """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -221,6 +221,9 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True, address is given, the URL will be constructed using the following schema: https::/. + *session*, a ``requests.session.Session``, the session to use to send the + queries. This argument is required to allow for connection reuse. + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query times out. If ``None``, the default, wait forever. @@ -257,12 +260,11 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True, url = 'https://{}:{}{}'.format(where, port, path) except ValueError: url = where - session = requests.sessions.Session() + # session = requests.sessions.Session() try: # set source port and source address # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py - session.mount('http://', SourceAddressAdapter(source)) - session.mount('https://', SourceAddressAdapter(source)) + session.mount(url, SourceAddressAdapter(source)) # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples if post: diff --git a/dns/query.pyi b/dns/query.pyi index 7c6f09c5..faedf8ce 100644 --- a/dns/query.pyi +++ b/dns/query.pyi @@ -1,5 +1,6 @@ from typing import Optional, Union, Dict, Generator, Any from . import tsig, rdatatype, rdataclass, name, message +from requests.sessions import Session try: import ssl @@ -7,7 +8,7 @@ except ImportError: class ssl(object): SSLContext = {} -def https(q : message.Message, where: str, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True, +def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True, verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0, one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message: pass diff --git a/tests/test_doh.py b/tests/test_doh.py index ded2f7ff..5e929764 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -17,6 +17,8 @@ import unittest import random +import requests + import dns.query import dns.rdatatype import dns.message @@ -27,18 +29,22 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query', 'https://dns11.quad9.net/dns-query'] class DNSOverHTTPSTestCase(unittest.TestCase): - nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS) + def setUp(self): + self.session = requests.sessions.Session() + + def tearDown(self): + self.session.close() def test_get_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, post=False) + r = dns.query.https(q, nameserver_url, self.session, post=False) self.assertTrue(q.is_response(r)) def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, post=True) + r = dns.query.https(q, nameserver_url, self.session, post=True) self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): @@ -46,7 +52,7 @@ class DNSOverHTTPSTestCase(unittest.TestCase): q = dns.message.make_query('example.com.', dns.rdatatype.A) # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query # So we're just going to do GET requests here - r = dns.query.https(q, nameserver_ip, post=False) + r = dns.query.https(q, nameserver_ip, self.session, post=False) self.assertTrue(q.is_response(r)) if __name__ == '__main__':