import requests
from requests_toolbelt.adapters.source import SourceAddressAdapter
from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
- have_doh = True
+ _have_requests = True
except ImportError: # pragma: no cover
- have_doh = False
+ _have_requests = False
+
+try:
+ import httpx
+ _have_httpx = True
+except ImportError: # pragma: no cover
+ _have_httpx = False
+
+have_doh = _have_requests or _have_httpx
try:
import ssl
if not have_doh:
raise NoDOH # pragma: no cover
+ _httpx_ok = True
+
wire = q.to_wire()
(af, _, source) = _destination_and_source(where, port, source, source_port,
False)
transport_adapter = None
+ transport = None
headers = {
"accept": "application/dns-message"
}
elif af == socket.AF_INET6:
url = 'https://[{}]:{}{}'.format(where, port, path)
elif bootstrap_address is not None:
+ _httpx_ok = False
split_url = urllib.parse.urlsplit(where)
headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
- transport_adapter = HostHeaderSSLAdapter()
+ if _have_requests:
+ transport_adapter = HostHeaderSSLAdapter()
else:
url = where
if source is not None:
# set source port and source address
- transport_adapter = SourceAddressAdapter(source)
+ if _have_httpx:
+ if source_port == 0:
+ transport = httpx.HTTPTransport(local_address=source[0])
+ else:
+ _httpx_ok = False
+ if _have_requests:
+ transport_adapter = SourceAddressAdapter(source)
+
+ if not _httpx_ok and not _have_requests:
+ raise NoDOH('Cannot use httpx for this operation, and '
+ 'requests is not available.')
with contextlib.ExitStack() as stack:
+ if session:
+ if _have_httpx:
+ _is_httpx = isinstance(session, httpx.Client)
+ else:
+ _is_httpx = False
+ if _is_httpx and not _httpx_ok:
+ # we can't use this session
+ session = None
if not session:
- session = stack.enter_context(requests.sessions.Session())
+ if _have_httpx and _httpx_ok:
+ _is_httpx = True
+ session = stack.enter_context(httpx.Client(http1=True,
+ http2=True,
+ verify=verify,
+ transport=transport))
+ else:
+ _is_httpx = False
+ session = stack.enter_context(requests.sessions.Session())
if transport_adapter:
session.mount(url, transport_adapter)
"content-type": "application/dns-message",
"content-length": str(len(wire))
})
- response = session.post(url, headers=headers, data=wire,
- timeout=timeout, verify=verify)
+ if _is_httpx:
+ response = session.post(url, headers=headers, content=wire,
+ timeout=timeout)
+ else:
+ response = session.post(url, headers=headers, data=wire,
+ timeout=timeout, verify=verify)
else:
wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
- response = session.get(url, headers=headers,
- timeout=timeout, verify=verify,
- params={"dns": wire})
+ if _is_httpx:
+ wire = wire.decode() # httpx does a repr() if we give it bytes
+ response = session.get(url, headers=headers,
+ timeout=timeout,
+ params={"dns": wire})
+ else:
+ response = session.get(url, headers=headers,
+ timeout=timeout, verify=verify,
+ params={"dns": wire})
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
import dns.rdatatype
import dns.resolver
-if dns.query.have_doh:
+if dns.query._have_requests:
import requests
from requests.exceptions import SSLError
+if dns.query._have_httpx:
+ import httpx
+
# Probe for IPv4 and IPv6
resolver_v4_addresses = []
resolver_v6_addresses = []
except socket.gaierror:
_network_available = False
-@unittest.skipUnless(dns.query.have_doh and _network_available,
+
+@unittest.skipUnless(dns.query._have_requests and _network_available,
"Python requests cannot be imported; no DNS over HTTPS (DOH)")
-class DNSOverHTTPSTestCase(unittest.TestCase):
+class DNSOverHTTPSTestCaseRequests(unittest.TestCase):
def setUp(self):
self.session = requests.sessions.Session()
self.assertTrue('8.8.4.4' in seen)
+@unittest.skipUnless(dns.query._have_httpx and _network_available,
+ "Python httpx cannot be imported; no DNS over HTTPS (DOH)")
+class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
+ def setUp(self):
+ self.session = httpx.Client(http1=True, http2=True, verify=True)
+
+ def tearDown(self):
+ self.session.close()
+
+ def test_get_request(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=False,
+ timeout=4)
+ self.assertTrue(q.is_response(r))
+
+ def test_post_request(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_url, session=self.session, post=True,
+ timeout=4)
+ self.assertTrue(q.is_response(r))
+
+ def test_build_url_from_ip(self):
+ self.assertTrue(resolver_v4_addresses or resolver_v6_addresses)
+ if resolver_v4_addresses:
+ nameserver_ip = random.choice(resolver_v4_addresses)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ # For some reason Google's DNS over HTTPS fails when you POST to
+ # https://8.8.8.8/dns-query
+ # So we're just going to do GET requests here
+ r = dns.query.https(q, nameserver_ip, session=self.session,
+ post=False, timeout=4)
+
+ self.assertTrue(q.is_response(r))
+ if resolver_v6_addresses:
+ nameserver_ip = random.choice(resolver_v6_addresses)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_ip, session=self.session,
+ post=False, timeout=4)
+ self.assertTrue(q.is_response(r))
+
+ def test_bootstrap_address(self):
+ # We test this to see if v4 is available
+ if resolver_v4_addresses:
+ ip = '185.228.168.168'
+ invalid_tls_url = 'https://{}/doh/family-filter/'.format(ip)
+ valid_tls_url = 'https://doh.cleanbrowsing.org/doh/family-filter/'
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ # make sure CleanBrowsing's IP address will fail TLS certificate
+ # check
+ with self.assertRaises(httpx.ConnectError):
+ dns.query.https(q, invalid_tls_url, session=self.session,
+ timeout=4)
+ # use host header
+ r = dns.query.https(q, valid_tls_url, session=self.session,
+ bootstrap_address=ip, timeout=4)
+ self.assertTrue(q.is_response(r))
+
+ def test_new_session(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+ q = dns.message.make_query('example.com.', dns.rdatatype.A)
+ r = dns.query.https(q, nameserver_url, timeout=4)
+ self.assertTrue(q.is_response(r))
+
+ def test_resolver(self):
+ res = dns.resolver.Resolver(configure=False)
+ res.nameservers = ['https://dns.google/dns-query']
+ answer = res.resolve('dns.google', 'A')
+ seen = set([rdata.address for rdata in answer])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+
if __name__ == '__main__':
unittest.main()